diff --git a/core/src/providers/anthropic.rs b/core/src/providers/anthropic.rs
index 90f8989d94eb..15ca097683b6 100644
--- a/core/src/providers/anthropic.rs
+++ b/core/src/providers/anthropic.rs
@@ -28,10 +28,110 @@ use super::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async};
 pub enum StopReason {
     StopSequence,
     MaxTokens,
+    EndTurn,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+#[serde(rename_all = "lowercase")]
+pub enum AnthropicChatMessageRole {
+    Assistant,
+    User,
+}
+
+impl From<AnthropicChatMessageRole> for ChatMessageRole {
+    fn from(value: AnthropicChatMessageRole) -> Self {
+        match value {
+            AnthropicChatMessageRole::Assistant => ChatMessageRole::Assistant,
+            AnthropicChatMessageRole::User => ChatMessageRole::User,
+        }
+    }
+}
+
+impl TryFrom<&ChatMessageRole> for AnthropicChatMessageRole {
+    type Error = anyhow::Error;
+
+    fn try_from(value: &ChatMessageRole) -> Result<Self, Self::Error> {
+        match value {
+            ChatMessageRole::Assistant => Ok(AnthropicChatMessageRole::Assistant),
+            ChatMessageRole::System => Ok(AnthropicChatMessageRole::User),
+            ChatMessageRole::User => Ok(AnthropicChatMessageRole::User),
+            ChatMessageRole::Function => Ok(AnthropicChatMessageRole::User),
+        }
+    }
+}
+
+impl ToString for AnthropicChatMessageRole {
+    fn to_string(&self) -> String {
+        match self {
+            AnthropicChatMessageRole::Assistant => String::from("assistant"),
+            AnthropicChatMessageRole::User => String::from("user"),
+        }
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct AnthropicContent {
+    pub r#type: String,
+    pub text: String,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
+pub struct AnthropicChatMessage {
+    pub content: Vec<AnthropicContent>,
+    pub role: AnthropicChatMessageRole,
+}
+
+impl TryFrom<&ChatMessage> for AnthropicChatMessage {
+    type Error = anyhow::Error;
+
+    fn try_from(cm: &ChatMessage) -> Result<Self, Self::Error> {
+        let role = AnthropicChatMessageRole::try_from(&cm.role)
+            .map_err(|e| anyhow!("Error converting role: {:?}", e))?;
+
+        let meta_prompt = match cm.role {
+            ChatMessageRole::User => match cm.name.as_ref() {
+                Some(name) => format!("[user: {}] ", name), // Include space here.
+                None => String::from(""),
+            },
+            ChatMessageRole::Function => match cm.name.as_ref() {
+                Some(name) => format!("[function_result: {}] ", name), // Include space here.
+ None => "[function_result]".to_string(), + }, + _ => String::from(""), + }; + + Ok(AnthropicChatMessage { + content: vec![AnthropicContent { + r#type: "text".to_string(), + text: format!( + "{}{}", + meta_prompt, + cm.content.clone().unwrap_or(String::from("")) + ), + }], + role, + }) + } } #[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Response { +pub struct Usage { + pub input_tokens: u64, + pub output_tokens: u64, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ChatResponse { + pub id: String, + pub model: String, + pub role: AnthropicChatMessageRole, + pub content: Vec, + pub stop_reason: Option, + pub usage: Usage, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct CompletionResponse { pub completion: String, pub stop_reason: Option, pub stop: Option, @@ -50,6 +150,51 @@ pub struct Error { pub error: ErrorDetail, } +// Streaming types + +#[derive(Serialize, Deserialize, Debug, Clone)] +struct StreamMessageStart { + pub r#type: String, + pub message: ChatResponse, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +struct StreamContentBlockStart { + pub r#type: String, + pub index: u64, + pub content_block: AnthropicContent, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +struct StreamContentBlockDelta { + pub r#type: String, + pub index: u64, + pub delta: AnthropicContent, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +struct StreamContentBlockStop { + pub r#type: String, + pub index: u64, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +struct ChatResponseDelta { + stop_reason: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct UsageDelta { + output_tokens: u64, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +struct StreamMessageDelta { + pub r#type: String, + pub delta: ChatResponseDelta, + pub usage: UsageDelta, +} + impl Display for ErrorDetail { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{},{}", self.r#type, self.message) @@ -71,154 +216,353 @@ impl AnthropicLLM { pub fn new(id: String) -> Self { Self { id, api_key: None } } - fn uri(&self) -> Result { - Ok("https://api.anthropic.com/v1/complete" + + fn messages_uri(&self) -> Result { + Ok("https://api.anthropic.com/v1/messages" .to_string() .parse::()?) } - fn chat_prompt(&self, messages: &Vec) -> String { - let mut prompt = messages - .iter() - .map(|cm| -> String { - format!( - "{}: {}{}", - match cm.role { - ChatMessageRole::System => "", - ChatMessageRole::Assistant => "\n\nAssistant", - ChatMessageRole::User => "\n\nHuman", - ChatMessageRole::Function => "\n\nHuman", - }, - match cm.role { - ChatMessageRole::System => "".to_string(), - ChatMessageRole::Assistant => "".to_string(), - ChatMessageRole::User => match cm.name.as_ref() { - Some(name) => format!("[user: {}] ", name), - None => "".to_string(), - }, - ChatMessageRole::Function => match cm.name.as_ref() { - Some(name) => format!("[function_result: {}] ", name), - None => "[function_result]".to_string(), - }, - }, - cm.content.as_ref().unwrap_or(&String::from("")).clone(), - ) - }) - .collect::>() - .join(""); - - prompt = format!("{}\n\nAssistant:", prompt); - - return prompt; + fn completions_uri(&self) -> Result { + Ok("https://api.anthropic.com/v1/complete" + .to_string() + .parse::()?) 
     }
 
     async fn chat_completion(
         &self,
-        messages: &Vec<ChatMessage>,
+        system: Option<String>,
+        messages: &Vec<AnthropicChatMessage>,
         temperature: f32,
         top_p: f32,
-        stop: &Vec<String>,
-        mut max_tokens: Option<i32>,
-    ) -> Result<LLMChatGeneration> {
+        stop_sequences: &Vec<String>,
+        max_tokens: i32,
+    ) -> Result<ChatResponse> {
         assert!(self.api_key.is_some());
 
-        let prompt = self.chat_prompt(messages);
-
-        let mut stop_tokens = stop.clone();
-        stop_tokens.push(String::from("\n\nHuman:"));
-        stop_tokens.push(String::from("\n\nAssistant:"));
+        let mut body = json!({
+            "model": self.id.clone(),
+            "messages": messages,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "stop_sequences": match stop_sequences.len() {
+                0 => None,
+                _ => Some(stop_sequences),
+            },
+        });
 
-        if max_tokens.is_none() || max_tokens.unwrap() == -1 {
-            let tokens = self.encode(&prompt).await?;
-            max_tokens = Some(std::cmp::min(
-                (self.context_size() - tokens.len()) as i32,
-                16384,
-            ));
+        if system.is_some() {
+            body["system"] = json!(system);
         }
 
-        let response = self
-            .completion(
-                self.api_key.clone().unwrap(),
-                &prompt,
-                match max_tokens {
-                    Some(m) => m,
-                    None => 256,
-                },
-                temperature,
-                top_p,
-                None,
-                stop_tokens.as_ref(),
-            )
+        let res = reqwest::Client::new()
+            .post(self.messages_uri()?.to_string())
+            .header("Content-Type", "application/json")
+            .header("X-API-Key", self.api_key.clone().unwrap())
+            .header("anthropic-version", "2023-06-01")
+            .json(&body)
+            .send()
             .await?;
 
-        return Ok(LLMChatGeneration {
-            created: utils::now(),
-            provider: ProviderID::Anthropic.to_string(),
-            model: self.id.clone(),
-            completions: vec![ChatMessage {
-                role: ChatMessageRole::Assistant,
-                content: Some(response.completion.clone()),
-                name: None,
-                function_call: None,
-            }],
-        });
+        let status = res.status();
+        let body = res.bytes().await?;
+
+        let mut b: Vec<u8> = vec![];
+        body.reader().read_to_end(&mut b)?;
+        let c: &[u8] = &b;
+        let response = match status {
+            reqwest::StatusCode::OK => {
+                let response: ChatResponse = serde_json::from_slice(c)?;
+                Ok(response)
+            }
+            _ => {
+                let error: Error = serde_json::from_slice(c)?;
+                Err(ModelError {
+                    message: format!("Anthropic API Error: {}", error.to_string()),
+                    retryable: None,
+                })
+            }
+        }?;
+
+        Ok(response)
     }
 
     pub async fn streamed_chat_completion(
         &self,
-        messages: &Vec<ChatMessage>,
+        system: Option<String>,
+        messages: &Vec<AnthropicChatMessage>,
         temperature: f32,
         top_p: f32,
-        stop: &Vec<String>,
-        mut max_tokens: Option<i32>,
+        stop_sequences: &Vec<String>,
+        max_tokens: i32,
         event_sender: UnboundedSender<Value>,
-    ) -> Result<LLMChatGeneration> {
-        let prompt = self.chat_prompt(messages);
-
-        let mut stop_tokens = stop.clone();
-        stop_tokens.push(String::from("\n\nHuman:"));
-        stop_tokens.push(String::from("\n\nAssistant:"));
-
-        if max_tokens.is_none() || max_tokens.unwrap() == -1 {
-            let tokens = self.encode(&prompt).await?;
-            max_tokens = Some(std::cmp::min(
-                (self.context_size() - tokens.len()) as i32,
-                16384,
-            ));
+    ) -> Result<ChatResponse> {
+        assert!(self.api_key.is_some());
+
+        let mut body = json!({
+            "model": self.id.clone(),
+            "messages": messages,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "stop_sequences": match stop_sequences.len() {
+                0 => None,
+                _ => Some(stop_sequences),
+            },
+            "stream": true,
+        });
+
+        if system.is_some() {
+            body["system"] = json!(system);
         }
 
-        let response = self
-            .streamed_completion(
-                self.api_key.clone().unwrap(),
-                prompt.as_str(),
-                match max_tokens {
-                    Some(m) => m,
-                    None => 256,
-                },
-                temperature,
-                top_p,
-                None,
-                &stop_tokens,
-                event_sender,
+        let https = HttpsConnector::new();
+        let url = self.messages_uri()?.to_string();
+
+        let mut builder = match es::ClientBuilder::for_url(url.as_str()) {
+            Ok(builder) => builder,
+            Err(e) => {
+                return Err(anyhow!(
+                    "Error creating Anthropic streaming client: {:?}",
+                    e
+                ))
+            }
+        };
+
+        builder = builder.method(String::from("POST"));
+        builder = match builder.header("Content-Type", "application/json") {
+            Ok(builder) => builder,
+            Err(e) => return Err(anyhow!("Error setting header: {:?}", e)),
+        };
+        builder = match builder.header("X-API-Key", self.api_key.clone().unwrap().as_str()) {
+            Ok(builder) => builder,
+            Err(e) => return Err(anyhow!("Error setting header: {:?}", e)),
+        };
+        builder = match builder.header("anthropic-version", "2023-06-01") {
+            Ok(builder) => builder,
+            Err(e) => return Err(anyhow!("Error setting header: {:?}", e)),
+        };
+
+        let client = builder
+            .body(body.to_string())
+            .reconnect(
+                es::ReconnectOptions::reconnect(true)
+                    .retry_initial(false)
+                    .delay(Duration::from_secs(1))
+                    .backoff_factor(2)
+                    .delay_max(Duration::from_secs(8))
+                    .build(),
             )
-            .await;
+            .build_with_conn(https);
 
-        return Ok(LLMChatGeneration {
-            created: utils::now(),
-            provider: ProviderID::Anthropic.to_string(),
-            model: self.id.clone(),
-            completions: vec![ChatMessage {
-                role: ChatMessageRole::Assistant,
-                content: Some(response?.completion.clone()),
-                name: None,
-                function_call: None,
-            }],
-        });
+        let mut stream = client.stream();
+
+        let mut final_response: Option<ChatResponse> = None;
+        'stream: loop {
+            match stream.try_next().await {
+                Ok(stream_next) => match stream_next {
+                    Some(es::SSE::Comment(comment)) => {
+                        println!("UNEXPECTED COMMENT {}", comment);
+                    }
+                    Some(es::SSE::Event(event)) => match event.event_type.as_str() {
+                        "message_start" => {
+                            let event: StreamMessageStart =
+                                match serde_json::from_str(event.data.as_str()) {
+                                    Ok(event) => event,
+                                    Err(error) => {
+                                        Err(anyhow!(
+                                            "Error parsing response from Anthropic: {:?} {:?}",
+                                            error,
+                                            event.data
+                                        ))?;
+                                        break 'stream;
+                                    }
+                                };
+
+                            final_response = Some(event.message.clone());
+                        }
+                        "content_block_start" => {
+                            let event: StreamContentBlockStart =
+                                match serde_json::from_str(event.data.as_str()) {
+                                    Ok(event) => event,
+                                    Err(error) => {
+                                        Err(anyhow!(
+                                            "Error parsing response from Anthropic: {:?} {:?}",
+                                            error,
+                                            event.data
+                                        ))?;
+                                        break 'stream;
+                                    }
+                                };
+
+                            match final_response.as_mut() {
+                                None => {
+                                    Err(anyhow!(
+                                        "Error streaming from Anthropic: \
+                                         missing `message_start`"
+                                    ))?;
+                                    break 'stream;
+                                }
+                                Some(response) => {
+                                    response.content.push(event.content_block.clone());
+                                    if event.content_block.text.len() > 0 {
+                                        let _ = event_sender.send(json!({
+                                            "type": "tokens",
+                                            "content": {
+                                                "text": event.content_block.text,
+                                            }
+                                        }));
+                                    }
+                                }
+                            }
+                        }
+                        "content_block_delta" => {
+                            let event: StreamContentBlockDelta =
+                                match serde_json::from_str(event.data.as_str()) {
+                                    Ok(event) => event,
+                                    Err(error) => {
+                                        Err(anyhow!(
+                                            "Error parsing response from Anthropic: {:?} {:?}",
+                                            error,
+                                            event.data
+                                        ))?;
+                                        break 'stream;
+                                    }
+                                };
+
+                            match event.delta.r#type.as_str() {
+                                "text_delta" => (),
+                                _ => {
+                                    Err(anyhow!(
+                                        "Error streaming from Anthropic: \
+                                         unexpected delta type: {:?}",
+                                        event.delta.r#type
+                                    ))?;
+                                    break 'stream;
+                                }
+                            }
+
+                            match final_response.as_mut() {
+                                None => {
+                                    Err(anyhow!(
+                                        "Error streaming from Anthropic: \
+                                         missing `message_start`"
+                                    ))?;
+                                    break 'stream;
+                                }
+                                Some(response) => match response.content.get_mut(0) {
+                                    None => {
+                                        Err(anyhow!(
+                                            "Error streaming from Anthropic: \
+                                             missing `content_block_start`"
+                                        ))?;
+                                        break 'stream;
+                                    }
+                                    Some(content) => {
+                                        content.text.push_str(event.delta.text.as_str());
+                                        if event.delta.text.len() > 0 {
+                                            let _ = event_sender.send(json!({
+                                                "type": "tokens",
+                                                "content": {
+                                                    "text": event.delta.text,
+                                                }
+                                            }));
+                                        }
+                                    }
+                                },
+                            }
+                        }
+                        "content_block_stop" => {
+                            let _: StreamContentBlockStop =
+                                match serde_json::from_str(event.data.as_str()) {
+                                    Ok(event) => event,
+                                    Err(error) => {
+                                        Err(anyhow!(
+                                            "Error parsing response from Anthropic: {:?} {:?}",
+                                            error,
+                                            event.data
+                                        ))?;
+                                        break 'stream;
+                                    }
+                                };
+                        }
+                        "message_delta" => {
+                            let event: StreamMessageDelta =
+                                match serde_json::from_str(event.data.as_str()) {
+                                    Ok(event) => event,
+                                    Err(error) => {
+                                        Err(anyhow!(
+                                            "Error parsing response from Anthropic: {:?} {:?}",
+                                            error,
+                                            event.data
+                                        ))?;
+                                        break 'stream;
+                                    }
+                                };
+
+                            match final_response.as_mut() {
+                                None => {
+                                    Err(anyhow!(
+                                        "Error streaming from Anthropic: \
+                                         missing `message_start`"
+                                    ))?;
+                                    break 'stream;
+                                }
+                                Some(response) => {
+                                    response.stop_reason = event.delta.stop_reason;
+                                    response.usage.output_tokens = event.usage.output_tokens;
+                                }
+                            }
+                        }
+                        "message_stop" => {
+                            break 'stream;
+                        }
+                        "error" => {
+                            let event: Error = match serde_json::from_str(event.data.as_str()) {
+                                Ok(event) => event,
+                                Err(_) => {
+                                    Err(anyhow!(
+                                        "Streaming error from Anthropic: {:?}",
+                                        event.data
+                                    ))?;
+                                    break 'stream;
+                                }
+                            };
+
+                            Err(ModelError {
+                                message: format!(
+                                    "Anthropic API Error: {}",
+                                    event.error.to_string()
+                                ),
+                                retryable: None,
+                            })?;
+                            break 'stream;
+                        }
+                        _ => (),
+                    },
+                    None => {
+                        println!("UNEXPECTED NONE");
+                        break 'stream;
+                    }
+                },
+                Err(error) => {
+                    Err(anyhow!("Error streaming from Anthropic: {:?}", error))?;
+                    break 'stream;
+                }
+            }
+        }
+
+        match final_response {
+            Some(response) => Ok(response),
+            None => Err(anyhow!("No response from Anthropic")),
+        }
     }
 
     pub async fn streamed_completion(
         &self,
-        api_key: String,
         prompt: &str,
         max_tokens_to_sample: i32,
         temperature: f32,
@@ -226,9 +570,11 @@ impl AnthropicLLM {
         top_k: Option<usize>,
         stop: &Vec<String>,
         event_sender: UnboundedSender<Value>,
-    ) -> Result<Response> {
+    ) -> Result<CompletionResponse> {
+        assert!(self.api_key.is_some());
+
         let https = HttpsConnector::new();
-        let url = self.uri()?.to_string();
+        let url = self.completions_uri()?.to_string();
 
         let mut builder = match es::ClientBuilder::for_url(url.as_str()) {
             Ok(builder) => builder,
@@ -245,7 +591,7 @@ impl AnthropicLLM {
             Ok(builder) => builder,
             Err(e) => return Err(anyhow!("Error setting header: {:?}", e)),
         };
-        builder = match builder.header("X-API-Key", api_key.as_str()) {
+        builder = match builder.header("X-API-Key", self.api_key.clone().unwrap().as_str()) {
             Ok(builder) => builder,
             Err(e) => return Err(anyhow!("Error setting header: {:?}", e)),
         };
@@ -282,7 +628,7 @@ impl AnthropicLLM {
 
         let mut stream = client.stream();
 
-        let mut final_response: Option<Response> = None;
+        let mut final_response: Option<CompletionResponse> = None;
         let mut completion = String::new();
         'stream: loop {
             match stream.try_next().await {
@@ -293,22 +639,22 @@ impl AnthropicLLM {
                 Some(es::SSE::Event(event)) => match event.event_type.as_str() {
                     "completion" => {
                         // println!("RESPONSE {} {}", event.event_type, event.data);
-                        let response: Response = match serde_json::from_str(event.data.as_str())
-                        {
-                            Ok(response) => response,
-                            Err(error) => {
-                                Err(anyhow!(
-                                    "Error parsing response from Anthropic: {:?} {:?}",
-                                    error,
-                                    event.data
-                                ))?;
-                                break 'stream;
-                            }
-                        };
+                        let response: CompletionResponse =
+                            match serde_json::from_str(event.data.as_str()) {
+                                Ok(response) => response,
+                                Err(error) => {
+                                    Err(anyhow!(
+                                        "Error parsing response from Anthropic: {:?} {:?}",
{:?}", + error, + event.data + ))?; + break 'stream; + } + }; match response.stop_reason { Some(stop_reason) => { - final_response = Some(Response { + final_response = Some(CompletionResponse { completion, stop_reason: Some(stop_reason), stop: response.stop.clone(), @@ -358,18 +704,19 @@ impl AnthropicLLM { async fn completion( &self, - api_key: String, prompt: &str, max_tokens_to_sample: i32, temperature: f32, top_p: f32, top_k: Option, stop: &Vec, - ) -> Result { + ) -> Result { + assert!(self.api_key.is_some()); + let res = reqwest::Client::new() - .post(self.uri()?.to_string()) + .post(self.completions_uri()?.to_string()) .header("Content-Type", "application/json") - .header("X-API-Key", api_key) + .header("X-API-Key", self.api_key.clone().unwrap()) .header("anthropic-version", "2023-06-01") .json(&json!({ "model": self.id.clone(), @@ -396,7 +743,7 @@ impl AnthropicLLM { let c: &[u8] = &b; let response = match status { reqwest::StatusCode::OK => { - let response: Response = serde_json::from_slice(c)?; + let response: CompletionResponse = serde_json::from_slice(c)?; Ok(response) } _ => { @@ -407,6 +754,7 @@ impl AnthropicLLM { }) } }?; + Ok(response) } } @@ -436,7 +784,7 @@ impl LLM for AnthropicLLM { } fn context_size(&self) -> usize { - if self.id.starts_with("claude-2.1") { + if self.id.starts_with("claude-2.1") || self.id.starts_with("claude-3") { 200000 } else { 100000 @@ -470,7 +818,7 @@ impl LLM for AnthropicLLM { let tokens = self.encode(prompt).await?; max_tokens = Some(std::cmp::min( (self.context_size() - tokens.len()) as i32, - 16384, + 4096, )); } } @@ -480,7 +828,6 @@ impl LLM for AnthropicLLM { let mut completions: Vec = vec![]; let response = match self .streamed_completion( - self.api_key.clone().unwrap(), prompt, match max_tokens { Some(m) => m, @@ -518,11 +865,10 @@ impl LLM for AnthropicLLM { // so we loop here and make n API calls let response = self .completion( - self.api_key.clone().unwrap(), prompt, match max_tokens { Some(m) => m, - None => 256, + None => 4096, }, temperature, match top_p { @@ -582,12 +928,14 @@ impl LLM for AnthropicLLM { top_p: Option, n: usize, stop: &Vec, - max_tokens: Option, + mut max_tokens: Option, _presence_penalty: Option, _frequency_penalty: Option, _extras: Option, event_sender: Option>, ) -> Result { + assert!(self.api_key.is_some()); + assert!(n > 0); if n > 1 { return Err(anyhow!( "Anthropic only supports generating one sample at a time." 
@@ -597,36 +945,119 @@ impl LLM for AnthropicLLM {
             return Err(anyhow!("Anthropic does not support chat functions."));
         }
 
-        match event_sender {
+        if let Some(m) = max_tokens {
+            if m == -1 {
+                max_tokens = Some(4096);
+            }
+        }
+
+        let system = match messages.get(0) {
+            Some(cm) => match cm.role {
+                ChatMessageRole::System => match cm.content.as_ref() {
+                    Some(c) => Some(c.clone()),
+                    None => None,
+                },
+                _ => None,
+            },
+            None => None,
+        };
+
+        let mut messages = messages
+            .iter()
+            .skip(match system.as_ref() {
+                Some(_) => 1,
+                None => 0,
+            })
+            .map(|cm| AnthropicChatMessage::try_from(cm))
+            .collect::<Result<Vec<AnthropicChatMessage>>>()?;
+
+        // Merge consecutive messages of the same role.
+        messages = messages.iter().fold(
+            vec![],
+            |mut acc: Vec<AnthropicChatMessage>, cm: &AnthropicChatMessage| {
+                match acc.last_mut() {
+                    Some(last) if last.role == cm.role => {
+                        last.content.extend(cm.content.clone());
+                    }
+                    _ => {
+                        acc.push(cm.clone());
+                    }
+                };
+                acc
+            },
+        );
+
+        // Collapse each message's content blocks into a single text block.
+        messages = messages
+            .iter()
+            .map(|cm| AnthropicChatMessage {
+                content: vec![AnthropicContent {
+                    r#type: String::from("text"),
+                    text: cm
+                        .content
+                        .iter()
+                        .map(|c| c.text.clone())
+                        .collect::<Vec<String>>()
+                        .join("\n"),
+                }],
+                role: cm.role.clone(),
+            })
+            .collect();
+
+        let c = match event_sender {
             Some(es) => {
-                return self
-                    .streamed_chat_completion(
-                        messages,
-                        temperature,
-                        match top_p {
-                            Some(p) => p,
-                            None => 1.0,
-                        },
-                        stop,
-                        max_tokens,
-                        es,
-                    )
-                    .await;
+                self.streamed_chat_completion(
+                    system,
+                    &messages,
+                    temperature,
+                    match top_p {
+                        Some(p) => p,
+                        None => 1.0,
+                    },
+                    stop,
+                    match max_tokens {
+                        Some(m) => m,
+                        None => 4096,
+                    },
+                    es,
+                )
+                .await?
             }
             None => {
-                return self
-                    .chat_completion(
-                        messages,
-                        temperature,
-                        match top_p {
-                            Some(p) => p,
-                            None => 1.0,
-                        },
-                        stop,
-                        max_tokens,
-                    )
-                    .await;
+                self.chat_completion(
+                    system,
+                    &messages,
+                    temperature,
+                    match top_p {
+                        Some(p) => p,
+                        None => 1.0,
+                    },
+                    stop,
+                    match max_tokens {
+                        Some(m) => m,
+                        None => 4096,
+                    },
+                )
+                .await?
             }
+        };
+
+        match c.content.first() {
+            None => Err(anyhow!("No content in response from Anthropic.")),
+            Some(content) => match content.r#type.as_str() {
+                "text" => Ok(LLMChatGeneration {
+                    created: utils::now(),
+                    provider: ProviderID::Anthropic.to_string(),
+                    model: self.id.clone(),
+                    completions: vec![ChatMessage {
+                        role: ChatMessageRole::Assistant,
+                        content: Some(content.text.clone()),
+                        name: None,
+                        function_call: None,
+                    }],
+                }),
+                _ => Err(anyhow!("Anthropic returned an unexpected content type.")),
+            },
         }
     }
 }
diff --git a/core/src/providers/mistral.rs b/core/src/providers/mistral.rs
index 38c4c2a2d536..6547e6649d17 100644
--- a/core/src/providers/mistral.rs
+++ b/core/src/providers/mistral.rs
@@ -208,7 +208,7 @@ impl MistralAILLM {
         messages: &Vec<MistralAIChatMessage>,
         temperature: f32,
         top_p: f32,
-        max_tokens: i32,
+        max_tokens: Option<i32>,
         event_sender: Option<UnboundedSender<Value>>,
     ) -> Result<ChatCompletion> {
         let url = uri.to_string();
@@ -373,11 +373,13 @@ impl MistralAILLM {
                     finish_reason: None,
                 })
                 .collect::<Vec<_>>(),
-            // The `created` timestamp is absent in the initial stream chunk (in ms), defaulting to the current time (in seconds).
+            // The `created` timestamp is absent in the initial stream chunk (in ms),
+            // defaulting to the current time (in seconds).
             created: f.created.map(|s| s * 1000).unwrap_or_else(now),
             id: f.id.clone(),
             model: f.model,
-            // The `object` field defaults to "start" when not present in the initial stream chunk.
+            // The `object` field defaults to "start" when not present in the initial stream
+            // chunk.
             object: f.object.unwrap_or(String::from("start")),
             usage: None,
         };
@@ -444,7 +446,7 @@ impl MistralAILLM {
         messages: &Vec<MistralAIChatMessage>,
         temperature: f32,
         top_p: f32,
-        max_tokens: i32,
+        max_tokens: Option<i32>,
     ) -> Result<ChatCompletion> {
         let mut body = json!({
             "messages": messages,
@@ -581,16 +583,9 @@ impl LLM for MistralAILLM {
         }
 
         // If max_tokens is not set or is -1, compute the max tokens based on the first message.
-        let first_message = &messages[0];
         let computed_max_tokens = match max_tokens.unwrap_or(-1) {
-            -1 => match &first_message.content {
-                Some(content) => {
-                    let tokens = self.encode(content).await?;
-                    (self.context_size() - tokens.len()) as i32
-                }
-                None => self.context_size() as i32,
-            },
-            _ => max_tokens.unwrap(),
+            -1 => None,
+            _ => max_tokens,
         };
 
         // TODO(flav): Handle `extras`.
@@ -609,7 +604,10 @@ impl LLM for MistralAILLM {
                     Some(t) => t,
                     None => 1.0,
                 },
-                computed_max_tokens,
+                match max_tokens {
+                    Some(-1) => None,
+                    _ => max_tokens,
+                },
                 event_sender,
             )
             .await?
diff --git a/front/components/assistant_builder/InstructionScreen.tsx b/front/components/assistant_builder/InstructionScreen.tsx
index 7bfe8e8d4c04..659e31bbcf27 100644
--- a/front/components/assistant_builder/InstructionScreen.tsx
+++ b/front/components/assistant_builder/InstructionScreen.tsx
@@ -22,6 +22,7 @@ import type {
 } from "@dust-tt/types";
 import type { WorkspaceType } from "@dust-tt/types";
 import {
+  CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG,
   CLAUDE_DEFAULT_MODEL_CONFIG,
   CLAUDE_INSTANT_DEFAULT_MODEL_CONFIG,
   Err,
@@ -147,11 +148,12 @@ function AdvancedSettings({
   const usedModelConfigs: ModelConfig[] = [
     GPT_4_TURBO_MODEL_CONFIG,
     GPT_3_5_TURBO_MODEL_CONFIG,
+    CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG,
     CLAUDE_DEFAULT_MODEL_CONFIG,
     CLAUDE_INSTANT_DEFAULT_MODEL_CONFIG,
+    MISTRAL_LARGE_MODEL_CONFIG,
     MISTRAL_MEDIUM_MODEL_CONFIG,
     MISTRAL_SMALL_MODEL_CONFIG,
-    MISTRAL_LARGE_MODEL_CONFIG,
     GEMINI_PRO_DEFAULT_MODEL_CONFIG,
   ];
diff --git a/front/lib/api/assistant/global_agents.ts b/front/lib/api/assistant/global_agents.ts
index ff12a4f90a70..a9611bfc0cb1 100644
--- a/front/lib/api/assistant/global_agents.ts
+++ b/front/lib/api/assistant/global_agents.ts
@@ -11,14 +11,14 @@ import type {
 } from "@dust-tt/types";
 import type { GlobalAgentStatus } from "@dust-tt/types";
 import {
-  GEMINI_PRO_DEFAULT_MODEL_CONFIG,
-  GPT_4_MODEL_CONFIG,
-  GPT_4_TURBO_MODEL_CONFIG,
-} from "@dust-tt/types";
-import {
+  CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG,
+  CLAUDE_3_SONNET_DEFAULT_MODEL_CONFIG,
   CLAUDE_DEFAULT_MODEL_CONFIG,
   CLAUDE_INSTANT_DEFAULT_MODEL_CONFIG,
+  GEMINI_PRO_DEFAULT_MODEL_CONFIG,
   GPT_3_5_TURBO_MODEL_CONFIG,
+  GPT_4_MODEL_CONFIG,
+  GPT_4_TURBO_MODEL_CONFIG,
   MISTRAL_LARGE_MODEL_CONFIG,
   MISTRAL_MEDIUM_MODEL_CONFIG,
   MISTRAL_SMALL_MODEL_CONFIG,
@@ -186,7 +186,7 @@ async function _getClaudeInstantGlobalAgent({
 }: {
   settings: GlobalAgentSettings | null;
 }): Promise<AgentConfigurationType> {
-  const status = settings ? settings.status : "active";
+  const status = settings ? settings.status : "disabled_by_admin";
   return {
     id: -1,
     sId: GLOBAL_AGENTS_SID.CLAUDE_INSTANT,
@@ -212,24 +212,28 @@ async function _getClaudeInstantGlobalAgent({
   };
 }
 
-async function _getClaudeGlobalAgent({
+async function _getClaude2GlobalAgent({
   auth,
   settings,
 }: {
   auth: Authenticator;
   settings: GlobalAgentSettings | null;
 }): Promise<AgentConfigurationType> {
-  const status = !auth.isUpgraded() ? "disabled_free_workspace" : "active";
+  let status = settings?.status ?? "disabled_by_admin";
"disabled_by_admin"; + if (!auth.isUpgraded()) { + status = "disabled_free_workspace"; + } + return { id: -1, - sId: GLOBAL_AGENTS_SID.CLAUDE, + sId: GLOBAL_AGENTS_SID.CLAUDE_2, version: 0, versionCreatedAt: null, versionAuthorId: null, - name: "claude", + name: "claude-2", description: CLAUDE_DEFAULT_MODEL_CONFIG.description, pictureUrl: "https://dust.tt/static/systemavatar/claude_avatar_full.png", - status: settings ? settings.status : status, + status, scope: "global", userListStatus: status === "active" ? "in-list" : "not-in-list", generation: { @@ -245,6 +249,80 @@ async function _getClaudeGlobalAgent({ }; } +async function _getClaude3SonnetGlobalAgent({ + auth, + settings, +}: { + auth: Authenticator; + settings: GlobalAgentSettings | null; +}): Promise { + let status = settings?.status ?? "active"; + if (!auth.isUpgraded()) { + status = "disabled_free_workspace"; + } + + return { + id: -1, + sId: GLOBAL_AGENTS_SID.CLAUDE_3_SONNET, + version: 0, + versionCreatedAt: null, + versionAuthorId: null, + name: "claude-3-sonnet", + description: CLAUDE_3_SONNET_DEFAULT_MODEL_CONFIG.description, + pictureUrl: "https://dust.tt/static/systemavatar/claude_avatar_full.png", + status, + scope: "global", + userListStatus: status === "active" ? "in-list" : "not-in-list", + generation: { + id: -1, + prompt: "", + model: { + providerId: CLAUDE_3_SONNET_DEFAULT_MODEL_CONFIG.providerId, + modelId: CLAUDE_3_SONNET_DEFAULT_MODEL_CONFIG.modelId, + }, + temperature: 0.7, + }, + action: null, + }; +} + +async function _getClaude3OpusGlobalAgent({ + auth, + settings, +}: { + auth: Authenticator; + settings: GlobalAgentSettings | null; +}): Promise { + let status = settings?.status ?? "active"; + if (!auth.isUpgraded()) { + status = "disabled_free_workspace"; + } + + return { + id: -1, + sId: GLOBAL_AGENTS_SID.CLAUDE_3_OPUS, + version: 0, + versionCreatedAt: null, + versionAuthorId: null, + name: "claude-3", + description: CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG.description, + pictureUrl: "https://dust.tt/static/systemavatar/claude_avatar_full.png", + status, + scope: "global", + userListStatus: status === "active" ? "in-list" : "not-in-list", + generation: { + id: -1, + prompt: "", + model: { + providerId: CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG.providerId, + modelId: CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG.modelId, + }, + temperature: 0.7, + }, + action: null, + }; +} + async function _getMistralLargeGlobalAgent({ auth, settings, @@ -351,11 +429,16 @@ async function _getMistralSmallGlobalAgent({ } async function _getGeminiProGlobalAgent({ + auth, settings, }: { + auth: Authenticator; settings: GlobalAgentSettings | null; }): Promise { - const status = settings ? settings.status : "disabled_by_admin"; + let status = settings?.status ?? 
"disabled_by_admin"; + if (!auth.isUpgraded()) { + status = "disabled_free_workspace"; + } return { id: -1, sId: GLOBAL_AGENTS_SID.GEMINI_PRO, @@ -775,8 +858,17 @@ export async function getGlobalAgent( case GLOBAL_AGENTS_SID.CLAUDE_INSTANT: agentConfiguration = await _getClaudeInstantGlobalAgent({ settings }); break; - case GLOBAL_AGENTS_SID.CLAUDE: - agentConfiguration = await _getClaudeGlobalAgent({ auth, settings }); + case GLOBAL_AGENTS_SID.CLAUDE_3_OPUS: + agentConfiguration = await _getClaude3OpusGlobalAgent({ auth, settings }); + break; + case GLOBAL_AGENTS_SID.CLAUDE_3_SONNET: + agentConfiguration = await _getClaude3SonnetGlobalAgent({ + auth, + settings, + }); + break; + case GLOBAL_AGENTS_SID.CLAUDE_2: + agentConfiguration = await _getClaude2GlobalAgent({ auth, settings }); break; case GLOBAL_AGENTS_SID.MISTRAL_LARGE: agentConfiguration = await _getMistralLargeGlobalAgent({ @@ -794,7 +886,7 @@ export async function getGlobalAgent( agentConfiguration = await _getMistralSmallGlobalAgent({ settings }); break; case GLOBAL_AGENTS_SID.GEMINI_PRO: - agentConfiguration = await _getGeminiProGlobalAgent({ settings }); + agentConfiguration = await _getGeminiProGlobalAgent({ auth, settings }); break; case GLOBAL_AGENTS_SID.SLACK: agentConfiguration = await _getSlackGlobalAgent(auth, { diff --git a/front/lib/assistant.ts b/front/lib/assistant.ts index 59ead9b9622f..7c06225de115 100644 --- a/front/lib/assistant.ts +++ b/front/lib/assistant.ts @@ -44,7 +44,9 @@ export enum GLOBAL_AGENTS_SID { INTERCOM = "intercom", GPT4 = "gpt-4", GPT35_TURBO = "gpt-3.5-turbo", - CLAUDE = "claude-2", + CLAUDE_3_OPUS = "claude-3-opus", + CLAUDE_3_SONNET = "claude-3-sonnet", + CLAUDE_2 = "claude-2", CLAUDE_INSTANT = "claude-instant-1", MISTRAL_LARGE = "mistral-large", MISTRAL_MEDIUM = "mistral-medium", @@ -64,7 +66,9 @@ const CUSTOM_ORDER: string[] = [ GLOBAL_AGENTS_SID.GITHUB, GLOBAL_AGENTS_SID.INTERCOM, GLOBAL_AGENTS_SID.GPT35_TURBO, - GLOBAL_AGENTS_SID.CLAUDE, + GLOBAL_AGENTS_SID.CLAUDE_3_OPUS, + GLOBAL_AGENTS_SID.CLAUDE_3_SONNET, + GLOBAL_AGENTS_SID.CLAUDE_2, GLOBAL_AGENTS_SID.CLAUDE_INSTANT, GLOBAL_AGENTS_SID.MISTRAL_LARGE, GLOBAL_AGENTS_SID.MISTRAL_MEDIUM, diff --git a/front/lib/specification.ts b/front/lib/specification.ts index bbd9daa98ae6..f28ad2ee51b1 100644 --- a/front/lib/specification.ts +++ b/front/lib/specification.ts @@ -198,7 +198,7 @@ export function addBlock( '_fun = (env) => {\n // return [{ role: "user", content: "hi!"}];\n}', functions_code: "_fun = (env) => {\n" + - " // See https://platform.openai.com/docs/guides/gpt/function-calling\n" + + " // See https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models\n" + " // return [{\n" + ' // name: "...",\n' + ' // description: "...",\n' + diff --git a/front/pages/api/w/[wId]/providers/[pId]/models.ts b/front/pages/api/w/[wId]/providers/[pId]/models.ts index 430509d58eee..adf0cdc0b16d 100644 --- a/front/pages/api/w/[wId]/providers/[pId]/models.ts +++ b/front/pages/api/w/[wId]/providers/[pId]/models.ts @@ -136,8 +136,6 @@ async function handler( { id: "command-light" }, { id: "command-nightly" }, { id: "command-light-nightly" }, - { id: "base" }, - { id: "base-light" }, ]; res.status(200).json({ models: cohereModels }); return; @@ -214,11 +212,25 @@ async function handler( return; case "anthropic": - const anthropic_models = [ - { id: "claude-2" }, - { id: "claude-2.1" }, - { id: "claude-instant-1.2" }, - ]; + let anthropic_models: { id: string }[] = []; + if (embed) { + anthropic_models = []; + } else { + if (chat) { + 
+          anthropic_models = [
+            { id: "claude-instant-1.2" },
+            { id: "claude-2.1" },
+            { id: "claude-3-sonnet-20240229" },
+            { id: "claude-3-opus-20240229" },
+          ];
+        } else {
+          anthropic_models = [
+            { id: "claude-instant-1.2" },
+            { id: "claude-2.1" },
+          ];
+        }
+      }
+
       res.status(200).json({ models: anthropic_models });
       return;
diff --git a/types/src/front/lib/assistant.ts b/types/src/front/lib/assistant.ts
index 01c8209f0bb9..14f23cd5f37d 100644
--- a/types/src/front/lib/assistant.ts
+++ b/types/src/front/lib/assistant.ts
@@ -46,10 +46,36 @@ export const GPT_3_5_TURBO_MODEL_CONFIG = {
   shortDescription: "OpenAI's fast model.",
 } as const;
 
+export const CLAUDE_3_OPUS_20240229_MODEL_ID = "claude-3-opus-20240229" as const;
+export const CLAUDE_3_SONNET_20240229_MODEL_ID =
+  "claude-3-sonnet-20240229" as const;
 export const CLAUDE_2_1_MODEL_ID = "claude-2.1" as const;
-export const CLAUDE_2_MODEL_ID = "claude-2" as const;
 export const CLAUDE_INSTANT_1_2_MODEL_ID = "claude-instant-1.2" as const;
 
+export const CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG = {
+  providerId: "anthropic" as const,
+  modelId: CLAUDE_3_OPUS_20240229_MODEL_ID,
+  displayName: "Claude 3 Opus",
+  contextSize: 200000,
+  recommendedTopK: 32,
+  largeModel: true,
+  description:
+    "Anthropic's Claude 3 Opus model, their most powerful model for highly complex tasks.",
+  shortDescription: "Anthropic's powerful model.",
+} as const;
+
+export const CLAUDE_3_SONNET_DEFAULT_MODEL_CONFIG = {
+  providerId: "anthropic" as const,
+  modelId: CLAUDE_3_SONNET_20240229_MODEL_ID,
+  displayName: "Claude 3 Sonnet",
+  contextSize: 200000,
+  recommendedTopK: 32,
+  largeModel: true,
+  description:
+    "Anthropic's Claude 3 Sonnet model, balancing intelligence and speed for enterprise workloads.",
+  shortDescription: "Anthropic's balanced model.",
+} as const;
+
 export const CLAUDE_DEFAULT_MODEL_CONFIG = {
   providerId: "anthropic" as const,
   modelId: CLAUDE_2_1_MODEL_ID,
@@ -57,7 +83,7 @@ export const CLAUDE_DEFAULT_MODEL_CONFIG = {
   contextSize: 200000,
   recommendedTopK: 32,
   largeModel: true,
-  description: "Anthropic's superior performance model (200k context).",
+  description: "Anthropic's Claude 2 model (200k context).",
   shortDescription: "Anthropic's smartest model.",
 } as const;
 
@@ -126,6 +152,8 @@ export const SUPPORTED_MODEL_CONFIGS = [
   GPT_3_5_TURBO_MODEL_CONFIG,
   GPT_4_MODEL_CONFIG,
   GPT_4_TURBO_MODEL_CONFIG,
+  CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG,
+  CLAUDE_3_SONNET_DEFAULT_MODEL_CONFIG,
   CLAUDE_DEFAULT_MODEL_CONFIG,
   CLAUDE_INSTANT_DEFAULT_MODEL_CONFIG,
   MISTRAL_LARGE_MODEL_CONFIG,
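
For reference, the request body that the new `chat_completion` / `streamed_chat_completion` assemble with `json!` serializes along these lines. A minimal sketch under the same `serde_json` usage as the patch — the model id and field values here are illustrative, not taken from the diff:

```rust
// Illustrative Messages API request body, mirroring what `chat_completion`
// builds above. Values are examples only.
use serde_json::json;

fn main() {
    let mut body = json!({
        "model": "claude-3-opus-20240229",
        "max_tokens": 4096,
        "temperature": 0.7,
        "top_p": 1.0,
        // In the patch this is `null` when the caller passes no stop sequences.
        "stop_sequences": ["\n\nHuman:"],
        // After the merge pass, roles strictly alternate and each message
        // carries a single text content block.
        "messages": [
            {
                "role": "user",
                "content": [{ "type": "text", "text": "Hello" }]
            }
        ],
    });

    // The system prompt travels as a top-level field, not as a message.
    body["system"] = json!("You are a helpful assistant.");

    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}
```
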