Anthropic: Chat completion & Claude 3 (Opus) (#4180)
* Working chat completion

* Doc for function calling

* Anthropic: chat streaming (see the sketch after this list)

* Working streaming

* Nits

* Add Claude 3 to front

* Chat/non-chat models for Anthropic

* More tweaks

* Lint

* Lint

* Remove logs
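
The streaming work is the bulk of the change and lives in core/src/providers/anthropic.rs, whose diff is collapsed below. As a rough sketch only, assuming Anthropic's publicly documented streaming event types (message_start, content_block_delta, message_stop, and so on) rather than the collapsed diff itself: the incremental text of a streamed chat completion arrives in "content_block_delta" server-sent events, which can be unpacked like this.

use serde_json::Value;

// Hypothetical helper (not from the diff): extract the incremental text from
// one parsed streaming event, assuming Anthropic's documented event shapes.
fn delta_text(event: &Value) -> Option<&str> {
    match event["type"].as_str()? {
        // "content_block_delta" events carry the next slice of generated text.
        "content_block_delta" => event["delta"]["text"].as_str(),
        // "message_start", "message_stop", etc. carry no text.
        _ => None,
    }
}

fn main() {
    let chunk: Value = serde_json::from_str(
        r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}"#,
    )
    .unwrap();
    assert_eq!(delta_text(&chunk), Some("Hi"));
}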
spolu authored Mar 7, 2024
1 parent 918d641 commit f2cff65
Showing 8 changed files with 780 additions and 213 deletions.
773 changes: 602 additions & 171 deletions core/src/providers/anthropic.rs

Large diffs are not rendered by default.
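
For orientation while this 773-line diff is collapsed: a minimal sketch of the request body such a provider issues, assuming the public Anthropic Messages API and not the actual file contents. The model id matches the one registered in front/pages/api/w/[wId]/providers/[pId]/models.ts below.

use serde_json::{json, Value};

// Minimal sketch (public Anthropic Messages API assumed; not the actual
// provider code): the JSON body for a chat completion, optionally streamed.
fn messages_body(stream: bool) -> Value {
    json!({
        "model": "claude-3-opus-20240229",
        // Unlike OpenAI-style chat APIs, max_tokens is required here.
        "max_tokens": 1024,
        // The system prompt is a top-level field, not a "system" role message.
        "system": "You are a helpful assistant.",
        "messages": [
            { "role": "user", "content": "Hello, Claude!" }
        ],
        "stream": stream,
    })
}

The body is POSTed to https://api.anthropic.com/v1/messages with x-api-key and anthropic-version headers; when stream is true, the response is the server-sent-event stream sketched above.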

26 changes: 12 additions & 14 deletions core/src/providers/mistral.rs
@@ -208,7 +208,7 @@ impl MistralAILLM {
messages: &Vec<ChatMessage>,
temperature: f32,
top_p: f32,
- max_tokens: i32,
+ max_tokens: Option<i32>,
event_sender: Option<UnboundedSender<Value>>,
) -> Result<ChatCompletion> {
let url = uri.to_string();
@@ -373,11 +373,13 @@ impl MistralAILLM {
finish_reason: None,
})
.collect::<Vec<_>>(),
- // The `created` timestamp is absent in the initial stream chunk (in ms), defaulting to the current time (in seconds).
+ // The `created` timestamp is absent in the initial stream chunk (in ms),
+ // defaulting to the current time (in seconds).
created: f.created.map(|s| s * 1000).unwrap_or_else(now),
id: f.id.clone(),
model: f.model,
- // The `object` field defaults to "start" when not present in the initial stream chunk.
+ // The `object` field defaults to "start" when not present in the initial stream
+ // chunk.
object: f.object.unwrap_or(String::from("start")),
usage: None,
};
@@ -444,7 +446,7 @@ impl MistralAILLM {
messages: &Vec<ChatMessage>,
temperature: f32,
top_p: f32,
- max_tokens: i32,
+ max_tokens: Option<i32>,
) -> Result<ChatCompletion> {
let mut body = json!({
"messages": messages,
@@ -581,16 +583,9 @@ impl LLM for MistralAILLM {
}

// If max_tokens is not set or is -1, compute the max tokens based on the first message.
- let first_message = &messages[0];
let computed_max_tokens = match max_tokens.unwrap_or(-1) {
- -1 => match &first_message.content {
- Some(content) => {
- let tokens = self.encode(content).await?;
- (self.context_size() - tokens.len()) as i32
- }
- None => self.context_size() as i32,
- },
- _ => max_tokens.unwrap(),
+ -1 => None,
+ _ => max_tokens,
};

// TODO(flav): Handle `extras`.
@@ -609,7 +604,10 @@
Some(t) => t,
None => 1.0,
},
- computed_max_tokens,
+ match max_tokens {
+ Some(-1) => None,
+ _ => max_tokens,
+ },
event_sender,
)
.await?
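
The thread running through this file: max_tokens becomes Option<i32>, and the old -1 sentinel ("no limit") is normalized to None instead of being back-computed from the prompt's token count. A hedged sketch of how such an optional field typically reaches the request body; build_body is illustrative only and not a function in this diff.

use serde_json::{json, Value};

// Illustrative only (not the provider's actual code): normalize the -1
// sentinel to None and omit max_tokens from the body when unset.
fn build_body(messages: &Value, max_tokens: Option<i32>) -> Value {
    let max_tokens = match max_tokens {
        // Mirrors the `Some(-1) => None` mapping in the diff above.
        Some(-1) | None => None,
        t => t,
    };
    let mut body = json!({ "messages": messages });
    if let Some(m) = max_tokens {
        body["max_tokens"] = json!(m);
    }
    body
}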
4 changes: 3 additions & 1 deletion front/components/assistant_builder/InstructionScreen.tsx
@@ -22,6 +22,7 @@ import type {
} from "@dust-tt/types";
import type { WorkspaceType } from "@dust-tt/types";
import {
+ CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG,
CLAUDE_DEFAULT_MODEL_CONFIG,
CLAUDE_INSTANT_DEFAULT_MODEL_CONFIG,
Err,
@@ -147,11 +148,12 @@ function AdvancedSettings({
const usedModelConfigs: ModelConfig[] = [
GPT_4_TURBO_MODEL_CONFIG,
GPT_3_5_TURBO_MODEL_CONFIG,
+ CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG,
CLAUDE_DEFAULT_MODEL_CONFIG,
CLAUDE_INSTANT_DEFAULT_MODEL_CONFIG,
+ MISTRAL_LARGE_MODEL_CONFIG,
MISTRAL_MEDIUM_MODEL_CONFIG,
MISTRAL_SMALL_MODEL_CONFIG,
- MISTRAL_LARGE_MODEL_CONFIG,
GEMINI_PRO_DEFAULT_MODEL_CONFIG,
];

122 changes: 107 additions & 15 deletions front/lib/api/assistant/global_agents.ts
@@ -11,14 +11,14 @@ import type {
} from "@dust-tt/types";
import type { GlobalAgentStatus } from "@dust-tt/types";
import {
- GEMINI_PRO_DEFAULT_MODEL_CONFIG,
- GPT_4_MODEL_CONFIG,
- GPT_4_TURBO_MODEL_CONFIG,
} from "@dust-tt/types";
import {
+ CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG,
+ CLAUDE_3_SONNET_DEFAULT_MODEL_CONFIG,
CLAUDE_DEFAULT_MODEL_CONFIG,
CLAUDE_INSTANT_DEFAULT_MODEL_CONFIG,
+ GEMINI_PRO_DEFAULT_MODEL_CONFIG,
GPT_3_5_TURBO_MODEL_CONFIG,
+ GPT_4_MODEL_CONFIG,
+ GPT_4_TURBO_MODEL_CONFIG,
MISTRAL_LARGE_MODEL_CONFIG,
MISTRAL_MEDIUM_MODEL_CONFIG,
MISTRAL_SMALL_MODEL_CONFIG,
@@ -186,7 +186,7 @@ async function _getClaudeInstantGlobalAgent({
}: {
settings: GlobalAgentSettings | null;
}): Promise<AgentConfigurationType> {
- const status = settings ? settings.status : "active";
+ const status = settings ? settings.status : "disabled_by_admin";
return {
id: -1,
sId: GLOBAL_AGENTS_SID.CLAUDE_INSTANT,
@@ -212,24 +216,28 @@
};
}

- async function _getClaudeGlobalAgent({
+ async function _getClaude2GlobalAgent({
auth,
settings,
}: {
auth: Authenticator;
settings: GlobalAgentSettings | null;
}): Promise<AgentConfigurationType> {
- const status = !auth.isUpgraded() ? "disabled_free_workspace" : "active";
+ let status = settings?.status ?? "disabled_by_admin";
+ if (!auth.isUpgraded()) {
+ status = "disabled_free_workspace";
+ }

return {
id: -1,
- sId: GLOBAL_AGENTS_SID.CLAUDE,
+ sId: GLOBAL_AGENTS_SID.CLAUDE_2,
version: 0,
versionCreatedAt: null,
versionAuthorId: null,
name: "claude",
name: "claude-2",
description: CLAUDE_DEFAULT_MODEL_CONFIG.description,
pictureUrl: "https://dust.tt/static/systemavatar/claude_avatar_full.png",
- status: settings ? settings.status : status,
+ status,
scope: "global",
userListStatus: status === "active" ? "in-list" : "not-in-list",
generation: {
@@ -245,6 +249,80 @@ async function _getClaudeGlobalAgent({
};
}

async function _getClaude3SonnetGlobalAgent({
auth,
settings,
}: {
auth: Authenticator;
settings: GlobalAgentSettings | null;
}): Promise<AgentConfigurationType> {
let status = settings?.status ?? "active";
if (!auth.isUpgraded()) {
status = "disabled_free_workspace";
}

return {
id: -1,
sId: GLOBAL_AGENTS_SID.CLAUDE_3_SONNET,
version: 0,
versionCreatedAt: null,
versionAuthorId: null,
name: "claude-3-sonnet",
description: CLAUDE_3_SONNET_DEFAULT_MODEL_CONFIG.description,
pictureUrl: "https://dust.tt/static/systemavatar/claude_avatar_full.png",
status,
scope: "global",
userListStatus: status === "active" ? "in-list" : "not-in-list",
generation: {
id: -1,
prompt: "",
model: {
providerId: CLAUDE_3_SONNET_DEFAULT_MODEL_CONFIG.providerId,
modelId: CLAUDE_3_SONNET_DEFAULT_MODEL_CONFIG.modelId,
},
temperature: 0.7,
},
action: null,
};
}

async function _getClaude3OpusGlobalAgent({
auth,
settings,
}: {
auth: Authenticator;
settings: GlobalAgentSettings | null;
}): Promise<AgentConfigurationType> {
let status = settings?.status ?? "active";
if (!auth.isUpgraded()) {
status = "disabled_free_workspace";
}

return {
id: -1,
sId: GLOBAL_AGENTS_SID.CLAUDE_3_OPUS,
version: 0,
versionCreatedAt: null,
versionAuthorId: null,
name: "claude-3",
description: CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG.description,
pictureUrl: "https://dust.tt/static/systemavatar/claude_avatar_full.png",
status,
scope: "global",
userListStatus: status === "active" ? "in-list" : "not-in-list",
generation: {
id: -1,
prompt: "",
model: {
providerId: CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG.providerId,
modelId: CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG.modelId,
},
temperature: 0.7,
},
action: null,
};
}

async function _getMistralLargeGlobalAgent({
auth,
settings,
@@ -351,11 +429,16 @@ async function _getMistralSmallGlobalAgent({
}

async function _getGeminiProGlobalAgent({
+ auth,
settings,
}: {
+ auth: Authenticator;
settings: GlobalAgentSettings | null;
}): Promise<AgentConfigurationType> {
- const status = settings ? settings.status : "disabled_by_admin";
+ let status = settings?.status ?? "disabled_by_admin";
+ if (!auth.isUpgraded()) {
+ status = "disabled_free_workspace";
+ }
return {
id: -1,
sId: GLOBAL_AGENTS_SID.GEMINI_PRO,
@@ -775,8 +858,17 @@ export async function getGlobalAgent(
case GLOBAL_AGENTS_SID.CLAUDE_INSTANT:
agentConfiguration = await _getClaudeInstantGlobalAgent({ settings });
break;
- case GLOBAL_AGENTS_SID.CLAUDE:
- agentConfiguration = await _getClaudeGlobalAgent({ auth, settings });
+ case GLOBAL_AGENTS_SID.CLAUDE_3_OPUS:
+ agentConfiguration = await _getClaude3OpusGlobalAgent({ auth, settings });
+ break;
+ case GLOBAL_AGENTS_SID.CLAUDE_3_SONNET:
+ agentConfiguration = await _getClaude3SonnetGlobalAgent({
+ auth,
+ settings,
+ });
+ break;
+ case GLOBAL_AGENTS_SID.CLAUDE_2:
+ agentConfiguration = await _getClaude2GlobalAgent({ auth, settings });
break;
case GLOBAL_AGENTS_SID.MISTRAL_LARGE:
agentConfiguration = await _getMistralLargeGlobalAgent({
@@ -794,7 +886,7 @@
agentConfiguration = await _getMistralSmallGlobalAgent({ settings });
break;
case GLOBAL_AGENTS_SID.GEMINI_PRO:
- agentConfiguration = await _getGeminiProGlobalAgent({ settings });
+ agentConfiguration = await _getGeminiProGlobalAgent({ auth, settings });
break;
case GLOBAL_AGENTS_SID.SLACK:
agentConfiguration = await _getSlackGlobalAgent(auth, {
8 changes: 6 additions & 2 deletions front/lib/assistant.ts
@@ -44,7 +44,9 @@ export enum GLOBAL_AGENTS_SID {
INTERCOM = "intercom",
GPT4 = "gpt-4",
GPT35_TURBO = "gpt-3.5-turbo",
CLAUDE = "claude-2",
CLAUDE_3_OPUS = "claude-3-opus",
CLAUDE_3_SONNET = "claude-3-sonnet",
CLAUDE_2 = "claude-2",
CLAUDE_INSTANT = "claude-instant-1",
MISTRAL_LARGE = "mistral-large",
MISTRAL_MEDIUM = "mistral-medium",
@@ -64,7 +66,9 @@ const CUSTOM_ORDER: string[] = [
GLOBAL_AGENTS_SID.GITHUB,
GLOBAL_AGENTS_SID.INTERCOM,
GLOBAL_AGENTS_SID.GPT35_TURBO,
- GLOBAL_AGENTS_SID.CLAUDE,
+ GLOBAL_AGENTS_SID.CLAUDE_3_OPUS,
+ GLOBAL_AGENTS_SID.CLAUDE_3_SONNET,
+ GLOBAL_AGENTS_SID.CLAUDE_2,
GLOBAL_AGENTS_SID.CLAUDE_INSTANT,
GLOBAL_AGENTS_SID.MISTRAL_LARGE,
GLOBAL_AGENTS_SID.MISTRAL_MEDIUM,
2 changes: 1 addition & 1 deletion front/lib/specification.ts
@@ -198,7 +198,7 @@ export function addBlock(
'_fun = (env) => {\n // return [{ role: "user", content: "hi!"}];\n}',
functions_code:
"_fun = (env) => {\n" +
" // See https://platform.openai.com/docs/guides/gpt/function-calling\n" +
" // See https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models\n" +
" // return [{\n" +
' // name: "...",\n' +
' // description: "...",\n' +
26 changes: 19 additions & 7 deletions front/pages/api/w/[wId]/providers/[pId]/models.ts
@@ -136,8 +136,6 @@ async function handler(
{ id: "command-light" },
{ id: "command-nightly" },
{ id: "command-light-nightly" },
{ id: "base" },
{ id: "base-light" },
];
res.status(200).json({ models: cohereModels });
return;
@@ -214,11 +212,25 @@
return;

case "anthropic":
- const anthropic_models = [
- { id: "claude-2" },
- { id: "claude-2.1" },
- { id: "claude-instant-1.2" },
- ];
+ let anthropic_models: { id: string }[] = [];
+ if (embed) {
+ anthropic_models = [];
+ } else {
+ if (chat) {
+ anthropic_models = [
+ { id: "claude-instant-1.2" },
+ { id: "claude-2.1" },
+ { id: "claude-3-sonnet-20240229" },
+ { id: "claude-3-opus-20240229" },
+ ];
+ } else {
+ anthropic_models = [
+ { id: "claude-instant-1.2" },
+ { id: "claude-2.1" },
+ ];
+ }
+ }
+
res.status(200).json({ models: anthropic_models });
return;
