Skip to content

Commit

Permalink
Refactor llm block serde types (#5835)
Browse files Browse the repository at this point in the history
* Refactor llm block serde types

* ✨

* 📖

* 👕

* Use vec for MixedContentBlock

* Strengthen types

* s/llm_message/chat_messages + remove system.name

* Use new lines to join texts

* Implement Open AI own content block type.

* ✂️

* ✨

* Dont' add trailing new line
  • Loading branch information
flvndvd authored Jun 25, 2024
1 parent 473946b commit b45c344
Show file tree
Hide file tree
Showing 11 changed files with 604 additions and 348 deletions.
107 changes: 9 additions & 98 deletions core/src/blocks/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ use crate::blocks::block::{
parse_pair, replace_variables_in_string, Block, BlockResult, BlockType, Env,
};
use crate::deno::script::Script;
use crate::providers::llm::{
ChatFunction, ChatFunctionCall, ChatMessage, ChatMessageRole, LLMChatRequest,
};
use crate::providers::chat_messages::{AssistantChatMessage, ChatMessage, SystemChatMessage};
use crate::providers::llm::{ChatFunction, ChatMessageRole, LLMChatRequest};
use crate::providers::provider::ProviderID;
use crate::Rule;
use anyhow::{anyhow, Result};
Expand Down Expand Up @@ -129,7 +128,7 @@ impl Chat {

#[derive(Debug, Serialize, PartialEq)]
struct ChatValue {
message: ChatMessage,
message: AssistantChatMessage,
}

#[async_trait]
Expand Down Expand Up @@ -316,93 +315,9 @@ impl Block for Chat {
expecting an array of objects with fields `role`, possibly `name`, \
and `content` or `function_call(s)`.";

let mut messages = match messages_value {
Value::Array(a) => a
.into_iter()
.map(|v| match v {
Value::Object(o) => {
match (
o.get("role"),
o.get("content"),
o.get("function_call"),
o.get("function_calls"),
) {
(Some(Value::String(r)), Some(Value::String(c)), None, None) => {
Ok(ChatMessage {
role: ChatMessageRole::from_str(r)?,
name: match o.get("name") {
Some(Value::String(n)) => Some(n.clone()),
_ => None,
},
content: Some(c.clone()),
function_call: None,
function_calls: None,
function_call_id: match o.get("function_call_id") {
Some(Value::String(fcid)) => Some(fcid.to_string()),
_ => None,
},
})
}
(Some(Value::String(r)), None, Some(Value::Object(fc)), None) => {
// parse function call into ChatFunctionCall
match (fc.get("name"), fc.get("arguments"), fc.get("id")) {
(
Some(Value::String(n)),
Some(Value::String(a)),
Some(Value::String(id)),
) => Ok(ChatMessage {
role: ChatMessageRole::from_str(r)?,
name: match o.get("name") {
Some(Value::String(n)) => Some(n.clone()),
_ => None,
},
content: None,
function_call: Some(ChatFunctionCall {
id: id.clone(),
name: n.clone(),
arguments: a.clone(),
}),
function_calls: None,
function_call_id: None,
}),
_ => Err(anyhow!(MESSAGES_CODE_OUTPUT)),
}
}
(Some(Value::String(r)), None, None, Some(Value::Array(fcs))) => {
let function_calls = fcs
.into_iter()
.map(|fc| {
match (fc.get("name"), fc.get("arguments"), fc.get("id")) {
(
Some(Value::String(n)),
Some(Value::String(a)),
Some(Value::String(id)),
) => Ok(ChatFunctionCall {
id: id.clone(),
name: n.clone(),
arguments: a.clone(),
}),
_ => Err(anyhow!(MESSAGES_CODE_OUTPUT)),
}
})
.collect::<Result<Vec<_>, _>>()?;

Ok(ChatMessage {
role: ChatMessageRole::from_str(r)?,
name: None,
content: None,
function_call: None,
function_calls: Some(function_calls),
function_call_id: None,
})
}
_ => Err(anyhow!(MESSAGES_CODE_OUTPUT)),
}
}
_ => Err(anyhow!(MESSAGES_CODE_OUTPUT)),
})
.collect::<Result<Vec<ChatMessage>>>()?,
_ => Err(anyhow!(MESSAGES_CODE_OUTPUT))?,
let mut messages: Vec<ChatMessage> = match messages_value {
Value::Array(a) => serde_json::from_value(Value::Array(a))?,
_ => return Err(anyhow!(MESSAGES_CODE_OUTPUT)),
};

// Process functions.
Expand Down Expand Up @@ -481,14 +396,10 @@ impl Block for Chat {
if i.len() > 0 {
messages.insert(
0,
ChatMessage {
ChatMessage::System(SystemChatMessage {
role: ChatMessageRole::System,
name: None,
content: Some(i),
function_call: None,
function_calls: None,
function_call_id: None,
},
content: i,
}),
);
}

Expand Down
12 changes: 5 additions & 7 deletions core/src/blocks/llm.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::blocks::block::{
find_variables, parse_pair, replace_variables_in_string, Block, BlockResult, BlockType, Env,
};
use crate::providers::llm::{ChatMessage, ChatMessageRole, LLMChatRequest, LLMRequest, Tokens};
use crate::providers::chat_messages::{ChatMessage, ContentBlock, UserChatMessage};
use crate::providers::llm::{ChatMessageRole, LLMChatRequest, LLMRequest, Tokens};
use crate::providers::provider::ProviderID;
use crate::Rule;
use anyhow::{anyhow, Result};
Expand Down Expand Up @@ -422,14 +423,11 @@ impl Block for LLM {
{
true => {
let prompt = self.prompt(env)?;
let messages = vec![ChatMessage {
let messages = vec![ChatMessage::User(UserChatMessage {
role: ChatMessageRole::User,
name: None,
content: Some(prompt.clone()),
function_call: None,
function_calls: None,
function_call_id: None,
}];
content: ContentBlock::Text(prompt.clone()),
})];

let request = LLMChatRequest::new(
provider_id,
Expand Down
1 change: 1 addition & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub mod providers {
pub mod mistral;
pub mod openai;

pub mod chat_messages;
pub mod provider;
pub mod tiktoken {
pub mod tiktoken;
Expand Down
Loading

0 comments on commit b45c344

Please sign in to comment.