Skip to content

Commit

Permalink
Add support for Anthropic vision (#6324)
Browse files Browse the repository at this point in the history
* Add support for Anthropic vision

* Update Anthropic models with Vision support

* Use lifetime
  • Loading branch information
flvndvd authored Jul 18, 2024
1 parent f9a8ce1 commit 6b5271e
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 8 deletions.
120 changes: 115 additions & 5 deletions core/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ use crate::utils;
use crate::utils::ParseError;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine};
use eventsource_client as es;
use eventsource_client::Client as ESClient;
use futures::TryStreamExt;
use hyper::StatusCode;
use hyper::{body::Buf, Uri};
use reqwest::get;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::fmt::{self, Display};
use std::io::prelude::*;
use std::str::FromStr;
Expand Down Expand Up @@ -74,10 +77,18 @@ struct AnthropicContentToolUse {
input: Value,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
struct AnthropicImageContent {
r#type: String,
media_type: String,
data: String,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AnthropicContentType {
Text,
Image,
ToolUse,
ToolResult,
}
Expand All @@ -94,6 +105,9 @@ struct AnthropicContent {

#[serde(skip_serializing_if = "Option::is_none", flatten)]
tool_result: Option<AnthropicContentToolResult>,

#[serde(skip_serializing_if = "Option::is_none")]
source: Option<AnthropicImageContent>,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
Expand Down Expand Up @@ -182,10 +196,82 @@ struct AnthropicChatMessage {
role: AnthropicChatMessageRole,
}

impl TryFrom<&ChatMessage> for AnthropicChatMessage {
async fn fetch_image_base64(image_url: &str) -> Result<(String, AnthropicImageContent)> {
let response = get(image_url)
.await
.map_err(|e| anyhow!("Invalid image: {}", e))?;

let mime_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|ct| ct.to_str().ok())
.unwrap_or("application/octet-stream") // Default to a general binary type if MIME type is not found.
.to_string();

let bytes = response
.bytes()
.await
.map_err(|e| anyhow!("Invalid image, could not parse response {}", e))?;

Ok((
image_url.to_string(),
AnthropicImageContent {
r#type: "base64".to_string(),
media_type: mime_type,
data: general_purpose::STANDARD.encode(&bytes),
},
))
}

async fn fetch_and_encode_images(
messages: Vec<ChatMessage>,
) -> Result<HashMap<String, AnthropicImageContent>, anyhow::Error> {
let futures = messages
.into_iter()
.filter_map(|message| {
if let ChatMessage::User(user_msg) = message {
if let ContentBlock::Mixed(mixed_content) = user_msg.content {
let inner_futures = mixed_content
.into_iter()
.filter_map(|content| {
if let MixedContent::ImageContent(ic) = content {
let url = ic.image_url.url.clone();

Some(async move { fetch_image_base64(&url).await })
} else {
None
}
})
.collect::<Vec<_>>();
return Some(inner_futures);
}
}
None
})
.flatten()
.collect::<Vec<_>>();

let base64_pairs = futures::future::try_join_all(futures)
.await?
.into_iter()
.map(|(image_url, img_content)| (image_url.clone(), img_content))
.collect::<HashMap<_, _>>();

Ok(base64_pairs)
}

struct ChatMessageConversionInput<'a> {
chat_message: &'a ChatMessage,
base64_map: &'a HashMap<String, AnthropicImageContent>,
}

impl<'a> TryFrom<&'a ChatMessageConversionInput<'a>> for AnthropicChatMessage {
type Error = anyhow::Error;

fn try_from(cm: &ChatMessage) -> Result<Self, Self::Error> {
fn try_from(input: &ChatMessageConversionInput) -> Result<Self, Self::Error> {
let cm = input.chat_message;
let base64_map = input.base64_map;

match cm {
ChatMessage::Assistant(assistant_msg) => {
// Handling tool_uses.
Expand All @@ -204,6 +290,7 @@ impl TryFrom<&ChatMessage> for AnthropicChatMessage {
input: value,
}),
tool_result: None,
source: None,
})
})
.collect::<Result<Vec<AnthropicContent>>>()?,
Expand All @@ -217,6 +304,7 @@ impl TryFrom<&ChatMessage> for AnthropicChatMessage {
text: Some(text.clone()),
tool_result: None,
tool_use: None,
source: None,
});

// Combining all content into one vector using iterators.
Expand All @@ -241,6 +329,7 @@ impl TryFrom<&ChatMessage> for AnthropicChatMessage {
content: Some(function_msg.content.clone()),
}),
text: None,
source: None,
};

Ok(AnthropicChatMessage {
Expand All @@ -258,9 +347,20 @@ impl TryFrom<&ChatMessage> for AnthropicChatMessage {
text: Some(tc.text.clone()),
tool_result: None,
tool_use: None,
source: None,
}),
MixedContent::ImageContent(_) => {
Err(anyhow!("Vision is not supported for Anthropic."))
MixedContent::ImageContent(ic) => {
if let Some(base64_data) = base64_map.get(&ic.image_url.url) {
Ok(AnthropicContent {
r#type: AnthropicContentType::Image,
source: Some(base64_data.clone()),
text: None,
tool_use: None,
tool_result: None,
})
} else {
Err(anyhow!("Invalid Image."))
}
}
})
.collect::<Result<Vec<AnthropicContent>>>()?;
Expand All @@ -276,6 +376,7 @@ impl TryFrom<&ChatMessage> for AnthropicChatMessage {
text: Some(t.clone()),
tool_result: None,
tool_use: None,
source: None,
}],
role: AnthropicChatMessageRole::User,
}),
Expand All @@ -286,6 +387,7 @@ impl TryFrom<&ChatMessage> for AnthropicChatMessage {
text: Some(system_msg.content.clone()),
tool_result: None,
tool_use: None,
source: None,
}],
role: AnthropicChatMessageRole::User,
}),
Expand Down Expand Up @@ -1566,13 +1668,21 @@ impl LLM for AnthropicLLM {
None => None,
};

let base64_map = fetch_and_encode_images(messages.clone()).await?;
let mut messages = messages
.iter()
.skip(match system.as_ref() {
Some(_) => 1,
None => 0,
})
.map(|cm| AnthropicChatMessage::try_from(cm))
.map(|cm| {
let conversion_input = ChatMessageConversionInput {
chat_message: &cm,
base64_map: &base64_map,
};

AnthropicChatMessage::try_from(&conversion_input)
})
.collect::<Result<Vec<AnthropicChatMessage>>>()?;

// Group consecutive messages with the same role by appending their content. This is
Expand Down
6 changes: 3 additions & 3 deletions types/src/front/lib/assistant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ export const CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG: ModelConfigurationType = {
shortDescription: "Anthropic's largest model.",
isLegacy: false,
delimitersConfiguration: ANTHROPIC_DELIMITERS_CONFIGURATION,
supportsVision: false,
supportsVision: true,
};
export const CLAUDE_3_5_SONNET_DEFAULT_MODEL_CONFIG: ModelConfigurationType = {
providerId: "anthropic",
Expand All @@ -271,7 +271,7 @@ export const CLAUDE_3_5_SONNET_DEFAULT_MODEL_CONFIG: ModelConfigurationType = {
shortDescription: "Anthropic's latest model.",
isLegacy: false,
delimitersConfiguration: ANTHROPIC_DELIMITERS_CONFIGURATION,
supportsVision: false,
supportsVision: true,
};
export const CLAUDE_3_HAIKU_DEFAULT_MODEL_CONFIG: ModelConfigurationType = {
providerId: "anthropic",
Expand All @@ -285,7 +285,7 @@ export const CLAUDE_3_HAIKU_DEFAULT_MODEL_CONFIG: ModelConfigurationType = {
"Anthropic's Claude 3 Haiku model, cost effective and high throughput (200k context).",
shortDescription: "Anthropic's cost-effective model.",
isLegacy: false,
supportsVision: false,
supportsVision: true,
};
export const CLAUDE_2_DEFAULT_MODEL_CONFIG: ModelConfigurationType = {
providerId: "anthropic",
Expand Down

0 comments on commit 6b5271e

Please sign in to comment.