Skip to content

Commit

Permalink
refactor: move get_usage to provider trait (#506)
Browse files Browse the repository at this point in the history
Co-authored-by: Zaki Ali <[email protected]>
  • Loading branch information
lifeizhou-ap and zakiali authored Dec 22, 2024
1 parent 37e67e1 commit 1b0b70c
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 80 deletions.
20 changes: 18 additions & 2 deletions crates/goose-server/src/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ pub enum ProviderSettings {
temperature: Option<f32>,
#[serde(default)]
max_tokens: Option<i32>,
#[serde(default)]
context_limit: Option<usize>,
#[serde(default)]
estimate_factor: Option<f32>,
},
Groq {
#[serde(default = "default_groq_host")]
Expand All @@ -99,6 +103,10 @@ pub enum ProviderSettings {
temperature: Option<f32>,
#[serde(default)]
max_tokens: Option<i32>,
#[serde(default)]
context_limit: Option<usize>,
#[serde(default)]
estimate_factor: Option<f32>,
},
}

Expand Down Expand Up @@ -174,25 +182,33 @@ impl ProviderSettings {
model,
temperature,
max_tokens,
context_limit,
estimate_factor,
} => ProviderConfig::Google(GoogleProviderConfig {
host,
api_key,
model: ModelConfig::new(model)
.with_temperature(temperature)
.with_max_tokens(max_tokens),
.with_max_tokens(max_tokens)
.with_context_limit(context_limit)
.with_estimate_factor(estimate_factor),
}),
ProviderSettings::Groq {
host,
api_key,
model,
temperature,
max_tokens,
context_limit,
estimate_factor,
} => ProviderConfig::Groq(GroqProviderConfig {
host,
api_key,
model: ModelConfig::new(model)
.with_temperature(temperature)
.with_max_tokens(max_tokens),
.with_max_tokens(max_tokens)
.with_context_limit(context_limit)
.with_estimate_factor(estimate_factor),
}),
}
}
Expand Down
4 changes: 4 additions & 0 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,10 @@ mod tests {
/// Return the mock provider's model configuration stored on the struct.
fn get_model_config(&self) -> &ModelConfig {
    &self.model_config
}

/// Test stub for the `Provider::get_usage` trait method.
///
/// Ignores the response payload entirely and reports unknown token
/// counts (`None` for input, output, and total).
// `data` is intentionally unused here; the underscore prefix silences
// the `unused_variables` warning without changing the trait signature.
fn get_usage(&self, _data: &Value) -> anyhow::Result<Usage> {
    Ok(Usage::new(None, None, None))
}
}

#[test]
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod utils;

pub mod google;
pub mod groq;

#[cfg(test)]
pub mod mock;
#[cfg(test)]
Expand Down
52 changes: 26 additions & 26 deletions crates/goose/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,6 @@ impl AnthropicProvider {
Ok(Self { client, config })
}

fn get_usage(data: &Value) -> Result<Usage> {
// Extract usage data if available
if let Some(usage) = data.get("usage") {
let input_tokens = usage
.get("input_tokens")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
let output_tokens = usage
.get("output_tokens")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
let total_tokens = match (input_tokens, output_tokens) {
(Some(i), Some(o)) => Some(i + o),
_ => None,
};

Ok(Usage::new(input_tokens, output_tokens, total_tokens))
} else {
// If no usage data, return None for all values
Ok(Usage::new(None, None, None))
}
}

fn tools_to_anthropic_spec(tools: &[Tool]) -> Vec<Value> {
let mut unique_tools = HashSet::new();
let mut tool_specs = Vec::new();
Expand Down Expand Up @@ -212,6 +189,10 @@ impl AnthropicProvider {

#[async_trait]
impl Provider for AnthropicProvider {
/// Expose the model configuration held by the provider's config.
fn get_model_config(&self) -> &ModelConfig {
    self.config.model_config()
}

async fn complete(
&self,
system: &str,
Expand Down Expand Up @@ -261,15 +242,34 @@ impl Provider for AnthropicProvider {

// Parse response
let message = Self::parse_anthropic_response(response.clone())?;
let usage = Self::get_usage(&response)?;
let usage = self.get_usage(&response)?;
let model = get_model(&response);
let cost = cost(&usage, &model_pricing_for(&model));

Ok((message, ProviderUsage::new(model, usage, cost)))
}

fn get_model_config(&self) -> &ModelConfig {
self.config.model_config()
fn get_usage(&self, data: &Value) -> Result<Usage> {
// Extract usage data if available
if let Some(usage) = data.get("usage") {
let input_tokens = usage
.get("input_tokens")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
let output_tokens = usage
.get("output_tokens")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
let total_tokens = match (input_tokens, output_tokens) {
(Some(i), Some(o)) => Some(i + o),
_ => None,
};

Ok(Usage::new(input_tokens, output_tokens, total_tokens))
} else {
// If no usage data, return None for all values
Ok(Usage::new(None, None, None))
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions crates/goose/src/providers/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ impl Usage {
}

use async_trait::async_trait;
use serde_json::Value;

/// Base trait for AI providers (OpenAI, Anthropic, etc)
#[async_trait]
Expand All @@ -70,6 +71,8 @@ pub trait Provider: Send + Sync {
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage)>;

fn get_usage(&self, data: &Value) -> Result<Usage>;
}

#[cfg(test)]
Expand Down
21 changes: 13 additions & 8 deletions crates/goose/src/providers/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,6 @@ impl DatabricksProvider {
}
}

/// Parse token usage from a Databricks serving-endpoint response by
/// delegating to the shared OpenAI-format usage parser.
fn get_usage(data: &Value) -> Result<Usage> {
    get_openai_usage(data)
}

async fn post(&self, payload: Value) -> Result<Value> {
let url = format!(
"{}/serving-endpoints/{}/invocations",
Expand All @@ -74,14 +70,23 @@ impl DatabricksProvider {

#[async_trait]
impl Provider for DatabricksProvider {
/// Expose the model configuration held by the provider's config.
fn get_model_config(&self) -> &ModelConfig {
    self.config.model_config()
}

async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage)> {
// Prepare messages and tools
let messages_spec = messages_to_openai_spec(messages, &self.config.image_format, false);
let concat_tool_response_contents = false;
let messages_spec = messages_to_openai_spec(
messages,
&self.config.image_format,
concat_tool_response_contents,
);
let tools_spec = if !tools.is_empty() {
tools_to_openai_spec(tools)?
} else {
Expand Down Expand Up @@ -131,15 +136,15 @@ impl Provider for DatabricksProvider {

// Parse response
let message = openai_response_to_message(response.clone())?;
let usage = Self::get_usage(&response)?;
let usage = self.get_usage(&response)?;
let model = get_model(&response);
let cost = cost(&usage, &model_pricing_for(&model));

Ok((message, ProviderUsage::new(model, usage, cost)))
}

fn get_model_config(&self) -> &ModelConfig {
self.config.model_config()
/// Parse token usage from a Databricks serving-endpoint response by
/// delegating to the shared OpenAI-format usage parser.
fn get_usage(&self, data: &Value) -> Result<Usage> {
    get_openai_usage(data)
}
}

Expand Down
42 changes: 21 additions & 21 deletions crates/goose/src/providers/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,6 @@ impl GoogleProvider {
Ok(Self { client, config })
}

fn get_usage(&self, data: &Value) -> anyhow::Result<Usage> {
if let Some(usage_meta_data) = data.get("usageMetadata") {
let input_tokens = usage_meta_data
.get("promptTokenCount")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
let output_tokens = usage_meta_data
.get("candidatesTokenCount")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
let total_tokens = usage_meta_data
.get("totalTokenCount")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
Ok(Usage::new(input_tokens, output_tokens, total_tokens))
} else {
// If no usage data, return None for all values
Ok(Usage::new(None, None, None))
}
}

async fn post(&self, payload: Value) -> anyhow::Result<Value> {
let url = format!(
"{}/v1beta/models/{}:generateContent?key={}",
Expand Down Expand Up @@ -343,6 +322,27 @@ impl Provider for GoogleProvider {
let provider_usage = ProviderUsage::new(model, usage, None);
Ok((message, provider_usage))
}

fn get_usage(&self, data: &Value) -> anyhow::Result<Usage> {
if let Some(usage_meta_data) = data.get("usageMetadata") {
let input_tokens = usage_meta_data
.get("promptTokenCount")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
let output_tokens = usage_meta_data
.get("candidatesTokenCount")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
let total_tokens = usage_meta_data
.get("totalTokenCount")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
Ok(Usage::new(input_tokens, output_tokens, total_tokens))
} else {
// If no usage data, return None for all values
Ok(Usage::new(None, None, None))
}
}
}

#[cfg(test)] // Only compiles this module when running tests
Expand Down
21 changes: 13 additions & 8 deletions crates/goose/src/providers/groq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use crate::message::Message;
use crate::providers::base::{Provider, ProviderUsage, Usage};
use crate::providers::configs::{GroqProviderConfig, ModelConfig, ProviderModelConfig};
use crate::providers::openai_utils::{
create_openai_request_payload, get_openai_usage, openai_response_to_message,
create_openai_request_payload_with_concat_response_content, get_openai_usage,
openai_response_to_message,
};
use crate::providers::utils::{get_model, handle_response};
use async_trait::async_trait;
Expand All @@ -28,10 +29,6 @@ impl GroqProvider {
Ok(Self { client, config })
}

/// Parse token usage from a Groq chat-completions response by
/// delegating to the shared OpenAI-format usage parser.
fn get_usage(data: &Value) -> anyhow::Result<Usage> {
    get_openai_usage(data)
}

async fn post(&self, payload: Value) -> anyhow::Result<Value> {
let url = format!(
"{}/openai/v1/chat/completions",
Expand Down Expand Up @@ -61,17 +58,25 @@ impl Provider for GroqProvider {
messages: &[Message],
tools: &[Tool],
) -> anyhow::Result<(Message, ProviderUsage)> {
let payload =
create_openai_request_payload(&self.config.model, system, messages, tools, true)?;
let payload = create_openai_request_payload_with_concat_response_content(
&self.config.model,
system,
messages,
tools,
)?;

let response = self.post(payload).await?;

let message = openai_response_to_message(response.clone())?;
let usage = Self::get_usage(&response)?;
let usage = self.get_usage(&response)?;
let model = get_model(&response);

Ok((message, ProviderUsage::new(model, usage, None)))
}

/// Parse token usage from a Groq chat-completions response by
/// delegating to the shared OpenAI-format usage parser.
fn get_usage(&self, data: &Value) -> anyhow::Result<Usage> {
    get_openai_usage(data)
}
}

#[cfg(test)]
Expand Down
5 changes: 5 additions & 0 deletions crates/goose/src/providers/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use anyhow::Result;
use async_trait::async_trait;
use mcp_core::tool::Tool;
use rust_decimal_macros::dec;
use serde_json::Value;
use std::sync::Arc;
use std::sync::Mutex;

Expand Down Expand Up @@ -60,4 +61,8 @@ impl Provider for MockProvider {
))
}
}

/// Mock implementation of `Provider::get_usage`.
///
/// Ignores the response payload entirely and reports unknown token
/// counts (`None` for input, output, and total).
// `data` is intentionally unused here; the underscore prefix silences
// the `unused_variables` warning without changing the trait signature.
fn get_usage(&self, _data: &Value) -> Result<Usage> {
    Ok(Usage::new(None, None, None))
}
}
13 changes: 6 additions & 7 deletions crates/goose/src/providers/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ impl OllamaProvider {
Ok(Self { client, config })
}

/// Parse token usage from an Ollama chat-completions response by
/// delegating to the shared OpenAI-format usage parser.
fn get_usage(data: &Value) -> Result<Usage> {
    get_openai_usage(data)
}

async fn post(&self, payload: Value) -> Result<Value> {
let url = format!(
"{}/v1/chat/completions",
Expand All @@ -57,19 +53,22 @@ impl Provider for OllamaProvider {
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage)> {
let payload =
create_openai_request_payload(&self.config.model, system, messages, tools, false)?;
let payload = create_openai_request_payload(&self.config.model, system, messages, tools)?;

let response = self.post(payload).await?;

// Parse response
let message = openai_response_to_message(response.clone())?;
let usage = Self::get_usage(&response)?;
let usage = self.get_usage(&response)?;
let model = get_model(&response);
let cost = None;

Ok((message, ProviderUsage::new(model, usage, cost)))
}

/// Parse token usage from an Ollama chat-completions response by
/// delegating to the shared OpenAI-format usage parser.
fn get_usage(&self, data: &Value) -> Result<Usage> {
    get_openai_usage(data)
}
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit 1b0b70c

Please sign in to comment.