diff --git a/core/src/blocks/chat.rs b/core/src/blocks/chat.rs index 582c0c46feb7..c0c3bd9b1026 100644 --- a/core/src/blocks/chat.rs +++ b/core/src/blocks/chat.rs @@ -2,9 +2,8 @@ use crate::blocks::block::{ parse_pair, replace_variables_in_string, Block, BlockResult, BlockType, Env, }; use crate::deno::script::Script; -use crate::providers::llm::{ - ChatFunction, ChatFunctionCall, ChatMessage, ChatMessageRole, LLMChatRequest, -}; +use crate::providers::chat_messages::{AssistantChatMessage, ChatMessage, SystemChatMessage}; +use crate::providers::llm::{ChatFunction, ChatMessageRole, LLMChatRequest}; use crate::providers::provider::ProviderID; use crate::Rule; use anyhow::{anyhow, Result}; @@ -129,7 +128,7 @@ impl Chat { #[derive(Debug, Serialize, PartialEq)] struct ChatValue { - message: ChatMessage, + message: AssistantChatMessage, } #[async_trait] @@ -316,93 +315,9 @@ impl Block for Chat { expecting an array of objects with fields `role`, possibly `name`, \ and `content` or `function_call(s)`."; - let mut messages = match messages_value { - Value::Array(a) => a - .into_iter() - .map(|v| match v { - Value::Object(o) => { - match ( - o.get("role"), - o.get("content"), - o.get("function_call"), - o.get("function_calls"), - ) { - (Some(Value::String(r)), Some(Value::String(c)), None, None) => { - Ok(ChatMessage { - role: ChatMessageRole::from_str(r)?, - name: match o.get("name") { - Some(Value::String(n)) => Some(n.clone()), - _ => None, - }, - content: Some(c.clone()), - function_call: None, - function_calls: None, - function_call_id: match o.get("function_call_id") { - Some(Value::String(fcid)) => Some(fcid.to_string()), - _ => None, - }, - }) - } - (Some(Value::String(r)), None, Some(Value::Object(fc)), None) => { - // parse function call into ChatFunctionCall - match (fc.get("name"), fc.get("arguments"), fc.get("id")) { - ( - Some(Value::String(n)), - Some(Value::String(a)), - Some(Value::String(id)), - ) => Ok(ChatMessage { - role: ChatMessageRole::from_str(r)?, - 
name: match o.get("name") { - Some(Value::String(n)) => Some(n.clone()), - _ => None, - }, - content: None, - function_call: Some(ChatFunctionCall { - id: id.clone(), - name: n.clone(), - arguments: a.clone(), - }), - function_calls: None, - function_call_id: None, - }), - _ => Err(anyhow!(MESSAGES_CODE_OUTPUT)), - } - } - (Some(Value::String(r)), None, None, Some(Value::Array(fcs))) => { - let function_calls = fcs - .into_iter() - .map(|fc| { - match (fc.get("name"), fc.get("arguments"), fc.get("id")) { - ( - Some(Value::String(n)), - Some(Value::String(a)), - Some(Value::String(id)), - ) => Ok(ChatFunctionCall { - id: id.clone(), - name: n.clone(), - arguments: a.clone(), - }), - _ => Err(anyhow!(MESSAGES_CODE_OUTPUT)), - } - }) - .collect::, _>>()?; - - Ok(ChatMessage { - role: ChatMessageRole::from_str(r)?, - name: None, - content: None, - function_call: None, - function_calls: Some(function_calls), - function_call_id: None, - }) - } - _ => Err(anyhow!(MESSAGES_CODE_OUTPUT)), - } - } - _ => Err(anyhow!(MESSAGES_CODE_OUTPUT)), - }) - .collect::>>()?, - _ => Err(anyhow!(MESSAGES_CODE_OUTPUT))?, + let mut messages: Vec = match messages_value { + Value::Array(a) => serde_json::from_value(Value::Array(a))?, + _ => return Err(anyhow!(MESSAGES_CODE_OUTPUT)), }; // Process functions. 
@@ -481,14 +396,10 @@ impl Block for Chat { if i.len() > 0 { messages.insert( 0, - ChatMessage { + ChatMessage::System(SystemChatMessage { role: ChatMessageRole::System, - name: None, - content: Some(i), - function_call: None, - function_calls: None, - function_call_id: None, - }, + content: i, + }), ); } diff --git a/core/src/blocks/llm.rs b/core/src/blocks/llm.rs index 214dfe454904..b75fc828f8be 100644 --- a/core/src/blocks/llm.rs +++ b/core/src/blocks/llm.rs @@ -1,7 +1,8 @@ use crate::blocks::block::{ find_variables, parse_pair, replace_variables_in_string, Block, BlockResult, BlockType, Env, }; -use crate::providers::llm::{ChatMessage, ChatMessageRole, LLMChatRequest, LLMRequest, Tokens}; +use crate::providers::chat_messages::{ChatMessage, ContentBlock, UserChatMessage}; +use crate::providers::llm::{ChatMessageRole, LLMChatRequest, LLMRequest, Tokens}; use crate::providers::provider::ProviderID; use crate::Rule; use anyhow::{anyhow, Result}; @@ -422,14 +423,11 @@ impl Block for LLM { { true => { let prompt = self.prompt(env)?; - let messages = vec![ChatMessage { + let messages = vec![ChatMessage::User(UserChatMessage { role: ChatMessageRole::User, name: None, - content: Some(prompt.clone()), - function_call: None, - function_calls: None, - function_call_id: None, - }]; + content: ContentBlock::Text(prompt.clone()), + })]; let request = LLMChatRequest::new( provider_id, diff --git a/core/src/lib.rs b/core/src/lib.rs index 1babda7a5eee..f29494c3ef9b 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -31,6 +31,7 @@ pub mod providers { pub mod mistral; pub mod openai; + pub mod chat_messages; pub mod provider; pub mod tiktoken { pub mod tiktoken; diff --git a/core/src/providers/anthropic.rs b/core/src/providers/anthropic.rs index 0c5479cc5fdb..3b56085c4faa 100644 --- a/core/src/providers/anthropic.rs +++ b/core/src/providers/anthropic.rs @@ -1,6 +1,6 @@ use crate::providers::embedder::{Embedder, EmbedderVector}; use crate::providers::llm::{ - ChatMessage, 
ChatMessageRole, LLMChatGeneration, LLMGeneration, LLMTokenUsage, Tokens, LLM, + ChatMessageRole, LLMChatGeneration, LLMGeneration, LLMTokenUsage, Tokens, LLM, }; use crate::providers::provider::{ModelError, ModelErrorRetryOptions, Provider, ProviderID}; use crate::providers::tiktoken::tiktoken::anthropic_base_singleton; @@ -22,6 +22,7 @@ use std::str::FromStr; use std::time::Duration; use tokio::sync::mpsc::UnboundedSender; +use super::chat_messages::{AssistantChatMessage, ChatMessage, ContentBlock, MixedContent}; use super::llm::{ChatFunction, ChatFunctionCall}; use super::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async}; @@ -50,19 +51,6 @@ impl From for ChatMessageRole { } } -impl TryFrom<&ChatMessageRole> for AnthropicChatMessageRole { - type Error = anyhow::Error; - - fn try_from(value: &ChatMessageRole) -> Result { - match value { - ChatMessageRole::Assistant => Ok(AnthropicChatMessageRole::Assistant), - ChatMessageRole::System => Ok(AnthropicChatMessageRole::User), - ChatMessageRole::User => Ok(AnthropicChatMessageRole::User), - ChatMessageRole::Function => Ok(AnthropicChatMessageRole::User), - } - } -} - impl ToString for AnthropicChatMessageRole { fn to_string(&self) -> String { match self { @@ -197,82 +185,110 @@ impl TryFrom<&ChatMessage> for AnthropicChatMessage { type Error = anyhow::Error; fn try_from(cm: &ChatMessage) -> Result { - let role = AnthropicChatMessageRole::try_from(&cm.role) - .map_err(|e| anyhow!("Error converting role: {:?}", e))?; - - // Handling meta prompt. - let meta_prompt = match cm.role { - // If function_call_id is not set Anthropic has no way to correlate the tool_use with - // the tool_result. - ChatMessageRole::Function if cm.function_call_id.is_none() => cm - .name - .as_ref() - .map_or(String::new(), |name| format!("[tool: {}] ", name)), - _ => String::new(), - }; + match cm { + ChatMessage::Assistant(assistant_msg) => { + // Handling tool_uses. 
+ let tool_uses = match assistant_msg.function_calls.as_ref() { + Some(fc) => Some( + fc.iter() + .map(|function_call| { + let value = serde_json::from_str(function_call.arguments.as_str())?; + + Ok(AnthropicContent { + r#type: AnthropicContentType::ToolUse, + text: None, + tool_use: Some(AnthropicContentToolUse { + name: function_call.name.clone(), + id: function_call.id.clone(), + input: value, + }), + tool_result: None, + }) + }) + .collect::>>()?, + ), + None => None, + }; + + // Handling text. + let text = assistant_msg.content.as_ref().map(|text| AnthropicContent { + r#type: AnthropicContentType::Text, + text: Some(text.clone()), + tool_result: None, + tool_use: None, + }); + + // Combining all content into one vector using iterators. + let content_vec = text + .into_iter() + .chain(tool_uses.into_iter().flatten()) + .collect::>(); - // Handling tool_uses. - let tool_uses = match cm.function_calls.as_ref() { - Some(fc) => Some( - fc.iter() - .map(|function_call| { - let value = serde_json::from_str(function_call.arguments.as_str())?; - - Ok(AnthropicContent { - r#type: AnthropicContentType::ToolUse, - text: None, - tool_use: Some(AnthropicContentToolUse { - name: function_call.name.clone(), - id: function_call.id.clone(), - input: value, + Ok(AnthropicChatMessage { + content: content_vec, + role: AnthropicChatMessageRole::Assistant, + }) + } + ChatMessage::Function(function_msg) => { + // Handling tool_result. + let tool_result = AnthropicContent { + r#type: AnthropicContentType::ToolResult, + tool_use: None, + tool_result: Some(AnthropicContentToolResult { + tool_use_id: function_msg.function_call_id.clone(), + // TODO(2024-06-24 flav) This does not need to be Optionable. 
+ content: Some(function_msg.content.clone()), + }), + text: None, + }; + + Ok(AnthropicChatMessage { + content: vec![tool_result], + role: AnthropicChatMessageRole::User, + }) + } + ChatMessage::User(user_msg) => match &user_msg.content { + ContentBlock::Mixed(m) => { + let content: Vec = m + .into_iter() + .map(|mb| match mb { + MixedContent::TextContent(tc) => Ok(AnthropicContent { + r#type: AnthropicContentType::Text, + text: Some(tc.text.clone()), + tool_result: None, + tool_use: None, }), - tool_result: None, + MixedContent::ImageContent(_) => { + Err(anyhow!("Vision is not supported for Anthropic.")) + } }) - }) - .collect::>>()?, - ), - None => None, - }; + .collect::>>()?; - // Handling tool_result. - let tool_result = match cm.function_call_id.as_ref() { - Some(fcid) => Some(AnthropicContent { - r#type: AnthropicContentType::ToolResult, - tool_use: None, - tool_result: Some(AnthropicContentToolResult { - tool_use_id: fcid.clone(), - content: cm.content.clone(), + Ok(AnthropicChatMessage { + content, + role: AnthropicChatMessageRole::User, + }) + } + ContentBlock::Text(t) => Ok(AnthropicChatMessage { + content: vec![AnthropicContent { + r#type: AnthropicContentType::Text, + text: Some(t.clone()), + tool_result: None, + tool_use: None, + }], + role: AnthropicChatMessageRole::User, }), - text: None, - }), - None => None, - }; - - // Handling text. - let text = cm - .function_call_id - .is_none() - .then(|| { - cm.content.as_ref().map(|text| AnthropicContent { + }, + ChatMessage::System(system_msg) => Ok(AnthropicChatMessage { + content: vec![AnthropicContent { r#type: AnthropicContentType::Text, - text: Some(format!("{}{}", meta_prompt, text)), + text: Some(system_msg.content.clone()), tool_result: None, tool_use: None, - }) - }) - .flatten(); - - // Combining all content into one vector using iterators. 
- let content_vec = text - .into_iter() - .chain(tool_uses.into_iter().flatten()) - .chain(tool_result.into_iter()) - .collect::>(); - - Ok(AnthropicChatMessage { - content: content_vec, - role, - }) + }], + role: AnthropicChatMessageRole::User, + }), + } } } @@ -387,7 +403,7 @@ impl TryFrom for ChatResponse { // It takes the first tool call from the vector of AnthropicResponseContent, // potentially discarding others. Anthropic often returns the CoT content as a first message, // which gets combined with the first tool call in the resulting ChatMessage. -impl TryFrom for ChatMessage { +impl TryFrom for AssistantChatMessage { type Error = anyhow::Error; fn try_from(cr: ChatResponse) -> Result { @@ -426,13 +442,12 @@ impl TryFrom for ChatMessage { None }; - Ok(ChatMessage { + Ok(AssistantChatMessage { role: ChatMessageRole::Assistant, name: None, content: text_content, function_call, function_calls, - function_call_id: None, }) } } @@ -1543,11 +1558,8 @@ impl LLM for AnthropicLLM { } let system = match messages.get(0) { - Some(cm) => match cm.role { - ChatMessageRole::System => match cm.content.as_ref() { - Some(c) => Some(c.clone()), - None => None, - }, + Some(cm) => match cm { + ChatMessage::System(system_msg) => Some(system_msg.content.clone()), _ => None, }, None => None, @@ -1639,7 +1651,7 @@ impl LLM for AnthropicLLM { prompt_tokens: c.usage.input_tokens, completion_tokens: c.usage.output_tokens, }), - completions: ChatMessage::try_from(c).into_iter().collect(), + completions: AssistantChatMessage::try_from(c).into_iter().collect(), provider_request_id: request_id, }) } diff --git a/core/src/providers/azure_openai.rs b/core/src/providers/azure_openai.rs index bc409c7b3b07..3419e4ac8e7d 100644 --- a/core/src/providers/azure_openai.rs +++ b/core/src/providers/azure_openai.rs @@ -1,7 +1,8 @@ +use crate::providers::chat_messages::AssistantChatMessage; use crate::providers::embedder::{Embedder, EmbedderVector}; use crate::providers::llm::ChatFunction; use 
crate::providers::llm::Tokens; -use crate::providers::llm::{ChatMessage, LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM}; +use crate::providers::llm::{LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM}; use crate::providers::openai::{ chat_completion, completion, embed, streamed_chat_completion, streamed_completion, to_openai_messages, OpenAILLM, OpenAITool, OpenAIToolChoice, @@ -26,6 +27,8 @@ use std::str::FromStr; use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; +use super::chat_messages::ChatMessage; + #[derive(Serialize, Deserialize, Debug, Clone)] struct AzureOpenAIScaleSettings { scale_type: String, @@ -546,7 +549,7 @@ impl LLM for AzureOpenAILLM { completions: c .choices .iter() - .map(|c| ChatMessage::try_from(&c.message)) + .map(|c| AssistantChatMessage::try_from(&c.message)) .collect::>>()?, usage: c.usage.map(|usage| LLMTokenUsage { prompt_tokens: usage.prompt_tokens, diff --git a/core/src/providers/chat_messages.rs b/core/src/providers/chat_messages.rs new file mode 100644 index 000000000000..cc1c4d574e2a --- /dev/null +++ b/core/src/providers/chat_messages.rs @@ -0,0 +1,155 @@ +use serde::{Deserialize, Deserializer, Serialize}; +use serde_json::Value; + +use super::llm::{ChatFunctionCall, ChatMessageRole}; + +// User message. 
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(rename_all = "snake_case")]
+pub enum TextContentType {
+    Text,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct TextContent {
+    #[serde(rename = "type")]
+    pub r#type: TextContentType,
+    pub text: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct ImageUrlContent {
+    pub url: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(rename_all = "snake_case")]
+pub enum ImageContentType {
+    ImageUrl,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct ImageContent {
+    pub r#type: ImageContentType,
+    pub image_url: ImageUrlContent,
+}
+
+// Define an enum for mixed content
+#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
+#[serde(untagged)]
+pub enum MixedContent {
+    TextContent(TextContent),
+    ImageContent(ImageContent),
+}
+
+#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
+#[serde(untagged)]
+pub enum ContentBlock {
+    Text(String),
+    Mixed(Vec<MixedContent>),
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(deny_unknown_fields)]
+pub struct UserChatMessage {
+    pub content: ContentBlock,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    pub role: ChatMessageRole,
+}
+
+// System message.
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(deny_unknown_fields)]
+pub struct SystemChatMessage {
+    pub content: String,
+    pub role: ChatMessageRole,
+}
+
+// Assistant message.
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(deny_unknown_fields)]
+pub struct AssistantChatMessage {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub function_call: Option<ChatFunctionCall>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub function_calls: Option<Vec<ChatFunctionCall>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    pub role: ChatMessageRole,
+}
+
+// Function message.
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(deny_unknown_fields)]
+pub struct FunctionChatMessage {
+    pub content: String,
+    pub function_call_id: String,
+    pub name: Option<String>,
+    pub role: ChatMessageRole,
+}
+
+// Enum representing different types of chat messages, where the `role` field
+// (mapped to ChatMessageRole) is used to determine the specific variant.
+
+#[derive(Debug, Serialize, PartialEq, Clone)]
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum ChatMessage {
+    Assistant(AssistantChatMessage),
+    Function(FunctionChatMessage),
+    User(UserChatMessage),
+    System(SystemChatMessage),
+}
+
+impl<'de> Deserialize<'de> for ChatMessage {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let v: Value = Value::deserialize(deserializer)?;
+        let role = v["role"]
+            .as_str()
+            .ok_or_else(|| serde::de::Error::custom("role field missing"))?;
+
+        match role {
+            "assistant" => {
+                let chat_msg: AssistantChatMessage =
+                    serde_json::from_value(v).map_err(serde::de::Error::custom)?;
+                Ok(ChatMessage::Assistant(chat_msg))
+            }
+            "function" => {
+                let chat_msg: FunctionChatMessage =
+                    serde_json::from_value(v).map_err(serde::de::Error::custom)?;
+                Ok(ChatMessage::Function(chat_msg))
+            }
+            "user" => {
+                let chat_msg: UserChatMessage =
+                    serde_json::from_value(v).map_err(serde::de::Error::custom)?;
+                Ok(ChatMessage::User(chat_msg))
+            }
+            "system" => {
+                let chat_msg: SystemChatMessage =
+                    serde_json::from_value(v).map_err(serde::de::Error::custom)?;
+                Ok(ChatMessage::System(chat_msg))
+            }
+            _ => Err(serde::de::Error::custom(format!("Invalid role: {}", role))),
+        }
+    }
+}
+
+impl ChatMessage {
+    pub fn get_role(&self) -> Option<&ChatMessageRole> {
+        match self {
+            ChatMessage::Assistant(msg) => Some(&msg.role),
+            ChatMessage::Function(msg) => Some(&msg.role),
+            ChatMessage::User(msg) => Some(&msg.role),
+            ChatMessage::System(msg) => Some(&msg.role),
+        }
+    }
+}
diff --git a/core/src/providers/google_ai_studio.rs b/core/src/providers/google_ai_studio.rs
index 0c0d5ec7f142..14dcb19cb135 100644
--- a/core/src/providers/google_ai_studio.rs
+++ b/core/src/providers/google_ai_studio.rs
@@ -13,6 +13,7 @@ use tokio::sync::mpsc::UnboundedSender;
 
 use crate::{
     providers::{
+        chat_messages::AssistantChatMessage,
         llm::Tokens,
         provider::{ModelError, ModelErrorRetryOptions},
     },
@@ -21,10 +22,11 @@ use crate::{
 };
 
 use super::{
+    chat_messages::{ChatMessage, ContentBlock, FunctionChatMessage, MixedContent},
     embedder::Embedder,
     llm::{
-        ChatFunction, ChatFunctionCall, ChatMessage, ChatMessageRole, LLMChatGeneration,
-        LLMGeneration, LLMTokenUsage, LLM,
+        ChatFunction, ChatFunctionCall, ChatMessageRole, LLMChatGeneration, LLMGeneration,
+        LLMTokenUsage, LLM,
     },
     provider::{Provider, ProviderID},
     tiktoken::tiktoken::{
@@ -55,16 +57,16 @@ pub struct GoogleAIStudioFunctionResponse {
     response: GoogleAiStudioFunctionResponseContent,
 }
 
-impl TryFrom<&ChatMessage> for GoogleAIStudioFunctionResponse {
+impl TryFrom<&FunctionChatMessage> for GoogleAIStudioFunctionResponse {
     type Error = anyhow::Error;
 
-    fn try_from(m: &ChatMessage) -> Result<Self, Self::Error> {
+    fn try_from(m: &FunctionChatMessage) -> Result<Self, Self::Error> {
         let name = m.name.clone().unwrap_or_default();
         Ok(GoogleAIStudioFunctionResponse {
             name: name.clone(),
             response: GoogleAiStudioFunctionResponseContent {
                 name,
-                content: m.content.clone().unwrap_or_default(),
+                content: m.content.clone(),
             },
         })
     }
@@ -144,50 +146,101 @@ pub struct Content {
impl TryFrom<&ChatMessage> for Content { type Error = anyhow::Error; - fn try_from(m: &ChatMessage) -> Result { - let role = match m.role { - ChatMessageRole::Assistant => String::from("model"), - ChatMessageRole::Function => String::from("function"), - _ => String::from("user"), - }; + fn try_from(cm: &ChatMessage) -> Result { + match cm { + ChatMessage::Assistant(assistant_msg) => { + let parts = match assistant_msg.function_calls { + Some(ref fcs) => fcs + .iter() + .map(|fc| { + Ok(Part { + text: assistant_msg.content.clone(), + function_call: Some(GoogleAIStudioFunctionCall::try_from(fc)?), + function_response: None, + }) + }) + .collect::, anyhow::Error>>()?, + None => { + if let Some(ref fc) = assistant_msg.function_call { + vec![Part { + text: assistant_msg.content.clone(), + function_call: Some(GoogleAIStudioFunctionCall::try_from(fc)?), + function_response: None, + }] + } else { + vec![Part { + text: assistant_msg.content.clone(), + function_call: None, + function_response: None, + }] + } + } + }; - let parts = match m.function_calls { - Some(ref fcs) => fcs - .iter() - .map(|fc| { - Ok(Part { - text: m.content.clone(), - function_call: Some(GoogleAIStudioFunctionCall::try_from(fc)?), - function_response: None, - }) + Ok(Content { + role: String::from("model"), + parts: Some(parts), }) - .collect::>>()?, - None => { - vec![Part { - text: match m.role { - // System is passed as a Content. We transform it here but it will be removed - // from the list of messages and passed as separate argument to the API. 
- ChatMessageRole::System => m.content.clone(), - ChatMessageRole::User => m.content.clone(), - ChatMessageRole::Assistant => m.content.clone(), - _ => None, - }, - + } + ChatMessage::Function(function_msg) => Ok(Content { + role: String::from("function"), + parts: Some(vec![Part { + text: None, function_call: None, - function_response: match m.role { - ChatMessageRole::Function => { - GoogleAIStudioFunctionResponse::try_from(m).ok() + function_response: GoogleAIStudioFunctionResponse::try_from(function_msg).ok(), + }]), + }), + ChatMessage::User(user_msg) => { + let text = match &user_msg.content { + ContentBlock::Mixed(m) => { + let result = m.iter().enumerate().try_fold( + String::new(), + |mut acc, (i, content)| { + match content { + MixedContent::ImageContent(_) => Err(anyhow!( + "Vision is not supported for Google AI Studio." + )), + MixedContent::TextContent(tc) => { + acc.push_str(&tc.text.trim()); + if i != m.len() - 1 { + // Add newline if it's not the last item. + acc.push('\n'); + } + Ok(acc) + } + } + }, + ); + + match result { + Ok(text) if !text.is_empty() => Ok(text), + Ok(_) => Err(anyhow!("Text is required.")), // Empty string. + Err(e) => Err(e), } - _ => None, - }, - }] + } + ContentBlock::Text(t) => Ok(t.clone()), + }?; + + Ok(Content { + role: String::from("user"), + parts: Some(vec![Part { + text: Some(text), + function_call: None, + function_response: None, + }]), + }) } - }; - - Ok(Content { - role, - parts: Some(parts), - }) + ChatMessage::System(system_msg) => Ok(Content { + role: String::from("user"), + parts: Some(vec![Part { + // System is passed as a Content. We transform it here but it will be removed + // from the list of messages and passed as separate argument to the API. + text: Some(system_msg.content.clone()), + function_call: None, + function_response: None, + }]), + }), + } } } @@ -475,8 +528,8 @@ impl LLM for GoogleAiStudioLLM { // Remove system message if first. 
let system = match messages.get(0) { - Some(cm) => match cm.role { - ChatMessageRole::System => Some(Content::try_from(cm)?), + Some(cm) => match cm { + ChatMessage::System(_) => Some(Content::try_from(cm)?), _ => None, }, None => None, @@ -576,7 +629,7 @@ impl LLM for GoogleAiStudioLLM { created: utils::now(), provider: ProviderID::GoogleAiStudio.to_string(), model: self.id().clone(), - completions: vec![ChatMessage { + completions: vec![AssistantChatMessage { name: None, function_call: match function_calls.first() { Some(fc) => Some(fc.clone()), @@ -586,7 +639,6 @@ impl LLM for GoogleAiStudioLLM { 0 => None, _ => Some(function_calls), }, - function_call_id: None, role: ChatMessageRole::Assistant, content, }], diff --git a/core/src/providers/llm.rs b/core/src/providers/llm.rs index 4a7d3502e39e..360839fbbab0 100644 --- a/core/src/providers/llm.rs +++ b/core/src/providers/llm.rs @@ -13,6 +13,8 @@ use std::str::FromStr; use tokio::sync::mpsc::UnboundedSender; use tracing::{error, info}; +use super::chat_messages::{AssistantChatMessage, ChatMessage}; + #[derive(Debug, Serialize, PartialEq, Clone, Deserialize)] pub struct Tokens { pub text: String, @@ -72,20 +74,6 @@ pub struct ChatFunctionCall { pub name: String, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct ChatMessage { - pub role: ChatMessageRole, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub function_call: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub function_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub function_call_id: Option, -} - #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] pub struct ChatFunction { pub name: String, @@ -104,7 +92,7 @@ pub struct LLMChatGeneration { pub created: u64, pub provider: String, pub model: String, - pub completions: Vec, + pub completions: Vec, pub usage: Option, pub provider_request_id: 
Option, } diff --git a/core/src/providers/mistral.rs b/core/src/providers/mistral.rs index d1a356b037ec..670072be147e 100644 --- a/core/src/providers/mistral.rs +++ b/core/src/providers/mistral.rs @@ -1,4 +1,5 @@ -use super::llm::{ChatFunction, ChatFunctionCall, ChatMessage}; +use super::chat_messages::{AssistantChatMessage, ChatMessage, ContentBlock, MixedContent}; +use super::llm::{ChatFunction, ChatFunctionCall}; use super::sentencepiece::sentencepiece::{ decode_async, encode_async, mistral_instruct_tokenizer_240216_model_v2_base_singleton, mistral_instruct_tokenizer_240216_model_v3_base_singleton, @@ -63,19 +64,6 @@ impl From for ChatMessageRole { } } -impl TryFrom<&ChatMessageRole> for MistralChatMessageRole { - type Error = anyhow::Error; - - fn try_from(value: &ChatMessageRole) -> Result { - match value { - ChatMessageRole::Assistant => Ok(MistralChatMessageRole::Assistant), - ChatMessageRole::System => Ok(MistralChatMessageRole::System), - ChatMessageRole::User => Ok(MistralChatMessageRole::User), - ChatMessageRole::Function => Ok(MistralChatMessageRole::Tool), - } - } -} - impl ToString for MistralChatMessageRole { fn to_string(&self) -> String { match self { @@ -104,8 +92,6 @@ struct MistralChatMessage { pub role: MistralChatMessageRole, pub content: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, @@ -158,37 +144,71 @@ impl TryFrom<&ChatMessage> for MistralChatMessage { type Error = anyhow::Error; fn try_from(cm: &ChatMessage) -> Result { - let mistral_role = MistralChatMessageRole::try_from(&cm.role) - .map_err(|e| anyhow!("Error converting role: {:?}", e))?; + match cm { + ChatMessage::Assistant(assistant_msg) => Ok(MistralChatMessage { + role: MistralChatMessageRole::Assistant, + content: assistant_msg.content.clone(), + tool_calls: match 
assistant_msg.function_calls.as_ref() { + Some(fc) => Some( + fc.iter() + .map(|f| MistralToolCall::try_from(f)) + .collect::>>()?, + ), + None => None, + }, + tool_call_id: None, + }), + ChatMessage::Function(function_msg) => Ok(MistralChatMessage { + role: MistralChatMessageRole::Tool, + content: Some(function_msg.content.clone()), + tool_calls: None, + tool_call_id: Some(sanitize_tool_call_id(&function_msg.function_call_id)), + }), + ChatMessage::User(user_msg) => Ok(MistralChatMessage { + role: MistralChatMessageRole::User, + content: match &user_msg.content { + ContentBlock::Mixed(m) => { + let result = m.iter().enumerate().try_fold( + String::new(), + |mut acc, (i, content)| { + match content { + MixedContent::ImageContent(_) => { + Err(anyhow!("Vision is not supported for Mistral.")) + } + MixedContent::TextContent(tc) => { + acc.push_str(&tc.text); + if i != m.len() - 1 { + // Add newline if it's not the last item. + acc.push('\n'); + } + Ok(acc) + } + } + }, + ); - Ok(MistralChatMessage { - content: match cm.content.as_ref() { - Some(c) => Some(c.clone()), - None => None, - }, - // Name is only supported for the Function/Tool role. - name: match mistral_role { - MistralChatMessageRole::Tool => cm.name.clone(), - _ => None, - }, - role: mistral_role, - tool_calls: match cm.function_calls.as_ref() { - Some(fc) => Some( - fc.iter() - .map(|f| MistralToolCall::try_from(f)) - .collect::>>()?, - ), - None => None, - }, - tool_call_id: cm - .function_call_id - .as_ref() - .map(|id| sanitize_tool_call_id(id)), - }) + match result { + Ok(text) if !text.is_empty() => Some(text), + Ok(_) => None, // Empty string. 
+ Err(e) => return Err(e), + } + } + ContentBlock::Text(t) => Some(t.clone()), + }, + tool_calls: None, + tool_call_id: None, + }), + ChatMessage::System(system_msg) => Ok(MistralChatMessage { + role: MistralChatMessageRole::System, + content: Some(system_msg.content.clone()), + tool_calls: None, + tool_call_id: None, + }), + } } } -impl TryFrom<&MistralChatMessage> for ChatMessage { +impl TryFrom<&MistralChatMessage> for AssistantChatMessage { type Error = anyhow::Error; fn try_from(cm: &MistralChatMessage) -> Result { @@ -219,13 +239,12 @@ impl TryFrom<&MistralChatMessage> for ChatMessage { None }; - Ok(ChatMessage { + Ok(AssistantChatMessage { content, role, name: None, function_call, function_calls, - function_call_id: None, }) } } @@ -668,7 +687,6 @@ impl MistralAILLM { .map(|c| MistralChatChoice { message: MistralChatMessage { content: Some("".to_string()), - name: None, role: MistralChatMessageRole::Assistant, tool_calls: None, tool_call_id: None, @@ -1012,7 +1030,7 @@ impl LLM for MistralAILLM { completions: c .choices .iter() - .map(|c| ChatMessage::try_from(&c.message)) + .map(|c| AssistantChatMessage::try_from(&c.message)) .collect::>>()?, usage: c.usage.map(|u| LLMTokenUsage { completion_tokens: u.completion_tokens.unwrap_or(0), diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs index 5c9a7d525da7..e248f0fc201e 100644 --- a/core/src/providers/openai.rs +++ b/core/src/providers/openai.rs @@ -2,7 +2,7 @@ use crate::providers::embedder::{Embedder, EmbedderVector}; use crate::providers::llm::Tokens; use crate::providers::llm::{ChatFunction, ChatFunctionCall}; use crate::providers::llm::{ - ChatMessage, ChatMessageRole, LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM, + ChatMessageRole, LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM, }; use crate::providers::provider::{ModelError, ModelErrorRetryOptions, Provider, ProviderID}; use crate::providers::tiktoken::tiktoken::{ @@ -32,6 +32,8 @@ use std::time::Duration; use 
tokio::sync::mpsc::UnboundedSender; use tokio::time::timeout; +use super::chat_messages::{AssistantChatMessage, ChatMessage, ContentBlock, MixedContent}; + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Usage { pub prompt_tokens: u64, @@ -244,10 +246,51 @@ impl From for ChatMessageRole { } } +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "snake_case")] +pub enum OpenAITextContentType { + Text, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAITextContent { + #[serde(rename = "type")] + pub r#type: OpenAITextContentType, + pub text: String, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIImageUrlContent { + pub url: String, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "snake_case")] +pub enum OpenAIImageContentType { + ImageUrl, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIImageContent { + pub r#type: OpenAIImageContentType, + pub image_url: OpenAIImageUrlContent, +} + +// Define an enum for mixed content +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[serde(untagged)] +pub enum OpenAIContentBlock { + TextContent(OpenAITextContent), + ImageContent(OpenAIImageContent), +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAIContentBlockVec(Vec); + #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] pub struct OpenAIChatMessage { pub role: OpenAIChatMessageRole, - pub content: Option, + pub content: Option, #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -256,9 +299,22 @@ pub struct OpenAIChatMessage { pub tool_call_id: Option, } +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct OpenAICompletionChatMessage { + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + 
pub name: Option, + pub role: OpenAIChatMessageRole, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct OpenAIChatChoice { - pub message: OpenAIChatMessage, + pub message: OpenAICompletionChatMessage, pub index: usize, pub finish_reason: Option, } @@ -275,10 +331,10 @@ pub struct OpenAIChatCompletion { // This code performs a type conversion with information loss when converting to ChatFunctionCall. // It only supports one tool call, so it takes the first one from the vector of OpenAIToolCall, // hence potentially discarding other tool calls. -impl TryFrom<&OpenAIChatMessage> for ChatMessage { +impl TryFrom<&OpenAICompletionChatMessage> for AssistantChatMessage { type Error = anyhow::Error; - fn try_from(cm: &OpenAIChatMessage) -> Result { + fn try_from(cm: &OpenAICompletionChatMessage) -> Result { let role = ChatMessageRole::from(cm.role.clone()); let content = match cm.content.as_ref() { Some(c) => Some(c.clone()), @@ -311,49 +367,111 @@ impl TryFrom<&OpenAIChatMessage> for ChatMessage { None => None, }; - Ok(ChatMessage { + Ok(AssistantChatMessage { content, role, name, function_call, function_calls, - function_call_id: None, }) } } +impl TryFrom<&ContentBlock> for OpenAIContentBlockVec { + type Error = anyhow::Error; + + fn try_from(cm: &ContentBlock) -> Result { + match cm { + ContentBlock::Text(t) => Ok(OpenAIContentBlockVec(vec![ + OpenAIContentBlock::TextContent(OpenAITextContent { + r#type: OpenAITextContentType::Text, + text: t.clone(), + }), + ])), + ContentBlock::Mixed(m) => { + let content: Vec = m + .into_iter() + .map(|mb| match mb { + MixedContent::TextContent(tc) => { + Ok(OpenAIContentBlock::TextContent(OpenAITextContent { + r#type: OpenAITextContentType::Text, + text: tc.text.clone(), + })) + } + MixedContent::ImageContent(ic) => { + 
Ok(OpenAIContentBlock::ImageContent(OpenAIImageContent { + r#type: OpenAIImageContentType::ImageUrl, + image_url: OpenAIImageUrlContent { + url: ic.image_url.url.clone(), + }, + })) + } + }) + .collect::>>()?; + + Ok(OpenAIContentBlockVec(content)) + } + } + } +} + +impl TryFrom<&String> for OpenAIContentBlockVec { + type Error = anyhow::Error; + + fn try_from(t: &String) -> Result { + Ok(OpenAIContentBlockVec(vec![ + OpenAIContentBlock::TextContent(OpenAITextContent { + r#type: OpenAITextContentType::Text, + text: t.clone(), + }), + ])) + } +} + impl TryFrom<&ChatMessage> for OpenAIChatMessage { type Error = anyhow::Error; fn try_from(cm: &ChatMessage) -> Result { - // If `function_call_id` is present, `role` must be `function` and should be mapped to `Tool`. - // This is to maintain backward compatibility with the original `function` role used for content fragments. - let (role, tool_call_id) = match cm.function_call_id.as_ref() { - Some(fcid) => match OpenAIChatMessageRole::from(&cm.role) { - OpenAIChatMessageRole::Function => { - Ok((OpenAIChatMessageRole::Tool, Some(fcid.clone()))) - } - _ => Err(anyhow!( - "`function_call_id` is provided but `role` is not set to `function`" - )), - }, - _ => Ok((OpenAIChatMessageRole::from(&cm.role), None)), - }?; - - Ok(OpenAIChatMessage { - content: cm.content.clone(), - name: cm.name.clone(), - role, - tool_call_id, - tool_calls: match cm.function_calls.as_ref() { - Some(fc) => Some( - fc.into_iter() - .map(|f| OpenAIToolCall::try_from(f)) - .collect::, _>>()?, - ), - None => None, - }, - }) + match cm { + ChatMessage::Assistant(assistant_msg) => Ok(OpenAIChatMessage { + content: match &assistant_msg.content { + Some(c) => Some(OpenAIContentBlockVec::try_from(c)?), + None => None, + }, + name: assistant_msg.name.clone(), + role: OpenAIChatMessageRole::from(&assistant_msg.role), + tool_calls: match assistant_msg.function_calls.as_ref() { + Some(fc) => Some( + fc.into_iter() + .map(|f| OpenAIToolCall::try_from(f)) + 
.collect::, _>>()?, + ), + None => None, + }, + tool_call_id: None, + }), + ChatMessage::Function(function_msg) => Ok(OpenAIChatMessage { + content: Some(OpenAIContentBlockVec::try_from(&function_msg.content)?), + name: None, + role: OpenAIChatMessageRole::Tool, + tool_calls: None, + tool_call_id: Some(function_msg.function_call_id.clone()), + }), + ChatMessage::System(system_msg) => Ok(OpenAIChatMessage { + content: Some(OpenAIContentBlockVec::try_from(&system_msg.content)?), + name: None, + role: OpenAIChatMessageRole::from(&system_msg.role), + tool_calls: None, + tool_call_id: None, + }), + ChatMessage::User(user_msg) => Ok(OpenAIChatMessage { + content: Some(OpenAIContentBlockVec::try_from(&user_msg.content)?), + name: user_msg.name.clone(), + role: OpenAIChatMessageRole::from(&user_msg.role), + tool_calls: None, + tool_call_id: None, + }), + } } } @@ -1161,7 +1279,7 @@ pub async fn streamed_chat_completion( .choices .iter() .map(|c| OpenAIChatChoice { - message: OpenAIChatMessage { + message: OpenAICompletionChatMessage { content: Some("".to_string()), name: None, role: OpenAIChatMessageRole::System, @@ -1963,7 +2081,7 @@ impl LLM for OpenAILLM { completions: c .choices .iter() - .map(|c| ChatMessage::try_from(&c.message)) + .map(|c| AssistantChatMessage::try_from(&c.message)) .collect::>>()?, usage: c.usage.map(|usage| LLMTokenUsage { prompt_tokens: usage.prompt_tokens, diff --git a/core/src/providers/provider.rs b/core/src/providers/provider.rs index 550e36ee8fb1..90fb49c7c9bd 100644 --- a/core/src/providers/provider.rs +++ b/core/src/providers/provider.rs @@ -147,10 +147,10 @@ pub trait Provider { pub fn provider(t: ProviderID) -> Box { match t { - ProviderID::OpenAI => Box::new(OpenAIProvider::new()), - ProviderID::AzureOpenAI => Box::new(AzureOpenAIProvider::new()), ProviderID::Anthropic => Box::new(AnthropicProvider::new()), - ProviderID::Mistral => Box::new(MistralProvider::new()), + ProviderID::AzureOpenAI => Box::new(AzureOpenAIProvider::new()), 
ProviderID::GoogleAiStudio => Box::new(GoogleAiStudioProvider::new()), + ProviderID::Mistral => Box::new(MistralProvider::new()), + ProviderID::OpenAI => Box::new(OpenAIProvider::new()), } }