From 364bdd93c84c516d55fdf100e71e8e0330970889 Mon Sep 17 00:00:00 2001
From: Henry Fontanier
Date: Mon, 23 Dec 2024 17:35:27 +0100
Subject: [PATCH] feat: support o1 and o1-high-reasoning for custom assistants
 (#9616)

Co-authored-by: Henry Fontanier
---
 core/src/providers/openai.rs                  | 12 ++--
 .../assistant_builder/AssistantBuilder.tsx    |  1 -
 .../assistant_builder/InstructionScreen.tsx   | 29 +++-----
 .../submitAssistantBuilderForm.ts             |  2 +
 front/components/providers/types.ts           |  6 ++
 front/components/trackers/TrackerBuilder.tsx  |  1 -
 front/lib/api/assistant/configuration.ts      |  1 +
 front/lib/api/assistant/conversation.ts       | 21 +++++-
 front/lib/assistant.ts                        |  3 +-
 front/lib/swr/models.ts                       | 20 ++++++
 front/pages/api/w/[wId]/models.ts             | 67 +++++++++++++++++++
 sdks/js/src/types.ts                          |  2 +
 .../internal/agent_configuration.ts           |  7 ++
 types/src/front/lib/assistant.ts              | 14 +++-
 types/src/shared/feature_flags.ts             |  2 +
 15 files changed, 160 insertions(+), 28 deletions(-)
 create mode 100644 front/lib/swr/models.ts
 create mode 100644 front/pages/api/w/[wId]/models.ts

diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs
index 175fb2945aad..010133546328 100644
--- a/core/src/providers/openai.rs
+++ b/core/src/providers/openai.rs
@@ -1835,7 +1835,8 @@ impl LLM for OpenAILLM {
             Some(self.id.clone()),
             prompt,
             max_tokens,
-            temperature,
+            // [o1] O1 models do not support custom temperature.
+            if !model_is_o1 { temperature } else { 1.0 },
             n,
             match top_logprobs {
                 Some(l) => Some(l),
@@ -1879,7 +1880,8 @@ impl LLM for OpenAILLM {
             Some(self.id.clone()),
             prompt,
             max_tokens,
-            temperature,
+            // [o1] O1 models do not support custom temperature.
+            if !model_is_o1 { temperature } else { 1.0 },
             n,
             match top_logprobs {
                 Some(l) => Some(l),
@@ -2060,7 +2062,8 @@ impl LLM for OpenAILLM {
             &openai_messages,
             tools,
             tool_choice,
-            temperature,
+            // [o1] O1 models do not support custom temperature.
+            if !model_is_o1 { temperature } else { 1.0 },
             match top_p {
                 Some(t) => t,
                 None => 1.0,
@@ -2091,7 +2094,8 @@ impl LLM for OpenAILLM {
             &openai_messages,
             tools,
             tool_choice,
-            temperature,
+            // [o1] O1 models do not support custom temperature.
+            if !model_is_o1 { temperature } else { 1.0 },
             match top_p {
                 Some(t) => t,
                 None => 1.0,
diff --git a/front/components/assistant_builder/AssistantBuilder.tsx b/front/components/assistant_builder/AssistantBuilder.tsx
index a5bde0b97591..1eada5dd00fc 100644
--- a/front/components/assistant_builder/AssistantBuilder.tsx
+++ b/front/components/assistant_builder/AssistantBuilder.tsx
@@ -454,7 +454,6 @@ export default function AssistantBuilder({
               return (
                 <InstructionScreen
                   owner={owner}
-                  plan={plan}
diff --git a/front/components/assistant_builder/InstructionScreen.tsx b/front/components/assistant_builder/InstructionScreen.tsx
--- a/front/components/assistant_builder/InstructionScreen.tsx
+++ b/front/components/assistant_builder/InstructionScreen.tsx
 export function InstructionScreen({
   owner,
-  plan,
   builderState,
   setBuilderState,
   setEdited,
@@ -123,7 +117,6 @@ export function InstructionScreen({
   agentConfigurationId,
 }: {
   owner: WorkspaceType;
-  plan: PlanType;
   builderState: AssistantBuilderState;
   setBuilderState: (
     statefn: (state: AssistantBuilderState) => AssistantBuilderState
   ) => void;
@@ -325,7 +318,6 @@ export function InstructionScreen({
         <AdvancedSettings
           owner={owner}
-          plan={plan}
           generationSettings={builderState.generationSettings}
           setGenerationSettings={(generationSettings) => {
            setEdited(true);
@@ -400,6 +392,7 @@ function ModelList({ modelConfigs, onClick }: ModelListProps) {
     onClick({
       modelId: modelConfig.modelId,
       providerId: modelConfig.providerId,
+      reasoningEffort: modelConfig.reasoningEffort,
     });
   };
@@ -420,17 +413,21 @@ export function AdvancedSettings({
   owner,
-  plan,
   generationSettings,
   setGenerationSettings,
 }: {
   owner: WorkspaceType;
-  plan: PlanType;
   generationSettings: AssistantBuilderState["generationSettings"];
   setGenerationSettings: (
     generationSettingsSettings: AssistantBuilderState["generationSettings"]
   ) => void;
 }) {
+  const { models, isModelsLoading } = useModels({ owner });
+
+  if (isModelsLoading) {
+    return null;
+  }
+
   const supportedModelConfig = getSupportedModelConfig(
     generationSettings.modelSettings
   );
@@ -441,13 +438,7 @@ export function AdvancedSettings({
   const bestPerformingModelConfigs: ModelConfigurationType[] = [];
   const otherModelConfigs: ModelConfigurationType[] = [];
-  for (const modelConfig of USED_MODEL_CONFIGS) {
-    if (
-      !isProviderWhitelisted(owner, modelConfig.providerId) ||
-      (modelConfig.largeModel && !isUpgraded(plan))
-    ) {
-      continue;
-    }
+  for (const modelConfig of models) {
     if (isBestPerformingModel(modelConfig.modelId)) {
       bestPerformingModelConfigs.push(modelConfig);
     } else {
       otherModelConfigs.push(modelConfig);
     }
diff --git a/front/components/assistant_builder/submitAssistantBuilderForm.ts b/front/components/assistant_builder/submitAssistantBuilderForm.ts
index 838ae2280761..07a3c2014204 100644
--- a/front/components/assistant_builder/submitAssistantBuilderForm.ts
+++ b/front/components/assistant_builder/submitAssistantBuilderForm.ts
@@ -217,6 +217,8 @@ export async function submitAssistantBuilderForm({
         modelId: builderState.generationSettings.modelSettings.modelId,
         providerId: builderState.generationSettings.modelSettings.providerId,
         temperature: builderState.generationSettings.temperature,
+        reasoningEffort:
+          builderState.generationSettings.modelSettings.reasoningEffort,
       },
       maxStepsPerRun,
       visualizationEnabled: builderState.visualizationEnabled,
diff --git a/front/components/providers/types.ts b/front/components/providers/types.ts
index 846197d90108..ec7bb86a9968 100644
--- a/front/components/providers/types.ts
+++ b/front/components/providers/types.ts
@@ -17,6 +17,9 @@ import {
   MISTRAL_CODESTRAL_MODEL_CONFIG,
   MISTRAL_LARGE_MODEL_CONFIG,
   MISTRAL_SMALL_MODEL_CONFIG,
+  O1_HIGH_REASONING_MODEL_CONFIG,
+  O1_MINI_MODEL_CONFIG,
+  O1_MODEL_CONFIG,
   TOGETHERAI_LLAMA_3_3_70B_INSTRUCT_TURBO_MODEL_CONFIG,
   TOGETHERAI_QWEN_2_5_CODER_32B_INSTRUCT_MODEL_CONFIG,
   TOGETHERAI_QWEN_32B_PREVIEW_MODEL_CONFIG,
@@ -38,6 +41,9 @@ export const USED_MODEL_CONFIGS: readonly ModelConfig[] = [
   GPT_4O_MODEL_CONFIG,
   GPT_4O_MINI_MODEL_CONFIG,
   GPT_4_TURBO_MODEL_CONFIG,
+  O1_MODEL_CONFIG,
+  O1_MINI_MODEL_CONFIG,
+  O1_HIGH_REASONING_MODEL_CONFIG,
   CLAUDE_3_5_SONNET_DEFAULT_MODEL_CONFIG,
   CLAUDE_3_5_HAIKU_DEFAULT_MODEL_CONFIG,
   MISTRAL_LARGE_MODEL_CONFIG,
diff --git a/front/components/trackers/TrackerBuilder.tsx b/front/components/trackers/TrackerBuilder.tsx
index 5400501d092c..df6dd16944a4 100644
--- a/front/components/trackers/TrackerBuilder.tsx
+++ b/front/components/trackers/TrackerBuilder.tsx
@@ -380,7 +380,6 @@ export const TrackerBuilder = ({
       )}
       <AdvancedSettings
         owner={owner}
-        plan={plan}
diff --git a/front/lib/assistant.ts b/front/lib/assistant.ts
--- a/front/lib/assistant.ts
+++ b/front/lib/assistant.ts
       m.modelId === supportedModel.modelId &&
-      m.providerId === supportedModel.providerId
+      m.providerId === supportedModel.providerId &&
+      m.reasoningEffort === supportedModel.reasoningEffort
   ) as (typeof SUPPORTED_MODEL_CONFIGS)[number];
 }
diff --git a/front/lib/swr/models.ts b/front/lib/swr/models.ts
new file mode 100644
index 000000000000..3c5da217b471
--- /dev/null
+++ b/front/lib/swr/models.ts
@@ -0,0 +1,20 @@
+import type { LightWorkspaceType } from "@dust-tt/types";
+import type { Fetcher } from "swr";
+
+import { fetcher, useSWRWithDefaults } from "@app/lib/swr/swr";
+import type { GetAvailableModelsResponseType } from "@app/pages/api/w/[wId]/models";
+
+export function useModels({ owner }: { owner: LightWorkspaceType }) {
+  const modelsFetcher: Fetcher<GetAvailableModelsResponseType> = fetcher;
+
+  const { data, error } = useSWRWithDefaults(
+    `/api/w/${owner.sId}/models`,
+    modelsFetcher
+  );
+
+  return {
+    models: data ? data.models : [],
+    isModelsLoading: !error && !data,
+    isModelsError: !!error,
+  };
+}
diff --git a/front/pages/api/w/[wId]/models.ts b/front/pages/api/w/[wId]/models.ts
new file mode 100644
index 000000000000..4c7991762f23
--- /dev/null
+++ b/front/pages/api/w/[wId]/models.ts
@@ -0,0 +1,67 @@
+import type {
+  ModelConfigurationType,
+  WithAPIErrorResponse,
+} from "@dust-tt/types";
+import { isProviderWhitelisted } from "@dust-tt/types";
+import type { NextApiRequest, NextApiResponse } from "next";
+
+import { USED_MODEL_CONFIGS } from "@app/components/providers/types";
+import { withSessionAuthenticationForWorkspace } from "@app/lib/api/auth_wrappers";
+import type { Authenticator } from "@app/lib/auth";
+import { getFeatureFlags } from "@app/lib/auth";
+import { isUpgraded } from "@app/lib/plans/plan_codes";
+import { apiError } from "@app/logger/withlogging";
+
+export type GetAvailableModelsResponseType = {
+  models: ModelConfigurationType[];
+};
+
+async function handler(
+  req: NextApiRequest,
+  res: NextApiResponse<WithAPIErrorResponse<GetAvailableModelsResponseType>>,
+  auth: Authenticator
+): Promise<void> {
+  const owner = auth.getNonNullableWorkspace();
+  const plan = auth.plan();
+
+  switch (req.method) {
+    case "GET":
+      const featureFlags = await getFeatureFlags(owner);
+
+      const models: ModelConfigurationType[] = [];
+      for (const m of USED_MODEL_CONFIGS) {
+        if (
+          !isProviderWhitelisted(owner, m.providerId) ||
+          (m.largeModel && !isUpgraded(plan))
+        ) {
+          continue;
+        }
+
+        if (m.featureFlag && !featureFlags.includes(m.featureFlag)) {
+          continue;
+        }
+
+        if (
+          m.customAssistantFeatureFlag &&
+          !featureFlags.includes(m.customAssistantFeatureFlag)
+        ) {
+          continue;
+        }
+
+        models.push(m);
+      }
+
+      return res.status(200).json({ models });
+
+    default:
+      return apiError(req, res, {
+        status_code: 405,
+        api_error: {
+          type: "method_not_supported_error",
+          message: "The method passed is not supported, GET is expected.",
+        },
+      });
+  }
+}
+
+export default withSessionAuthenticationForWorkspace(handler);
diff --git a/sdks/js/src/types.ts b/sdks/js/src/types.ts
index 09f989994224..d6381cf6b48c 100644
--- a/sdks/js/src/types.ts
+++ b/sdks/js/src/types.ts
@@ -656,6 +656,8 @@ const WhitelistableFeaturesSchema = FlexibleEnumSchema<
   | "openai_o1_feature"
   | "openai_o1_mini_feature"
   | "openai_o1_high_reasoning_feature"
+  | "openai_o1_custom_assistants_feature"
+  | "openai_o1_high_reasoning_custom_assistants_feature"
   | "snowflake_connector_feature"
   | "index_private_slack_channel"
   | "conversations_jit_actions"
diff --git a/types/src/front/api_handlers/internal/agent_configuration.ts b/types/src/front/api_handlers/internal/agent_configuration.ts
index 239affa64373..2a720417c798 100644
--- a/types/src/front/api_handlers/internal/agent_configuration.ts
+++ b/types/src/front/api_handlers/internal/agent_configuration.ts
@@ -177,6 +177,13 @@ const ModelConfigurationSchema = t.intersection([
   }),
   // TODO(2024-11-04 flav) Clean up this legacy type.
   t.partial(multiActionsCommonFields),
+  t.partial({
+    reasoningEffort: t.union([
+      t.literal("low"),
+      t.literal("medium"),
+      t.literal("high"),
+    ]),
+  }),
 ]);
 
 const IsSupportedModelSchema = new t.Type<SupportedModel>(
   "SupportedModel",
diff --git a/types/src/front/lib/assistant.ts b/types/src/front/lib/assistant.ts
index cec53808b129..7459630e1b09 100644
--- a/types/src/front/lib/assistant.ts
+++ b/types/src/front/lib/assistant.ts
@@ -4,6 +4,7 @@ import {
 } from "../../front/assistant/agent";
 import { GenerationTokensEvent } from "../../front/assistant/generation";
 import { WorkspaceType } from "../../front/user";
+import { WhitelistableFeature } from "../../shared/feature_flags";
 import { ExtractSpecificKeys } from "../../shared/typescipt_utils";
 import { ioTsEnum } from "../../shared/utils/iots_utils";
 
@@ -196,6 +197,9 @@ export type ModelConfigurationType = {
 
   // Only used for O-series OpenAI models.
   reasoningEffort?: AgentReasoningEffort;
+
+  featureFlag?: WhitelistableFeature;
+  customAssistantFeatureFlag?: WhitelistableFeature;
 };
 
 // Should be used for all Open AI models older than gpt-4o-2024-08-06 to prevent issues
@@ -286,6 +290,8 @@ export const O1_MODEL_CONFIG: ModelConfigurationType = {
   shortDescription: "OpenAI's reasoning model.",
   isLegacy: false,
   supportsVision: true,
+  featureFlag: "openai_o1_feature",
+  customAssistantFeatureFlag: "openai_o1_custom_assistants_feature",
 };
 export const O1_HIGH_REASONING_MODEL_CONFIG: ModelConfigurationType = {
   providerId: "openai",
@@ -301,6 +307,9 @@ export const O1_HIGH_REASONING_MODEL_CONFIG: ModelConfigurationType = {
   isLegacy: false,
   supportsVision: true,
   reasoningEffort: "high",
+  featureFlag: "openai_o1_high_reasoning_feature",
+  customAssistantFeatureFlag:
+    "openai_o1_high_reasoning_custom_assistants_feature",
 };
 export const O1_MINI_MODEL_CONFIG: ModelConfigurationType = {
   providerId: "openai",
@@ -315,6 +324,8 @@ export const O1_MINI_MODEL_CONFIG: ModelConfigurationType = {
   shortDescription: "OpenAI's fast reasoning model.",
   isLegacy: false,
   supportsVision: false,
+  featureFlag: "openai_o1_mini_feature",
+  customAssistantFeatureFlag: "openai_o1_custom_assistants_feature",
 };
 
 const ANTHROPIC_DELIMITERS_CONFIGURATION = {
@@ -623,6 +634,7 @@ export const SUPPORTED_MODEL_CONFIGS: ModelConfigurationType[] = [
   GPT_4O_20240806_MODEL_CONFIG,
   GPT_4O_MINI_MODEL_CONFIG,
   O1_MODEL_CONFIG,
+  O1_HIGH_REASONING_MODEL_CONFIG,
   O1_MINI_MODEL_CONFIG,
   CLAUDE_3_OPUS_DEFAULT_MODEL_CONFIG,
   CLAUDE_3_5_SONNET_20240620_DEPRECATED_MODEL_CONFIG,
@@ -649,7 +661,7 @@ export type ModelConfig = (typeof SUPPORTED_MODEL_CONFIGS)[number];
 // pairs that are in SUPPORTED_MODELS
 export type SupportedModel = ExtractSpecificKeys<
   (typeof SUPPORTED_MODEL_CONFIGS)[number],
-  "providerId" | "modelId"
+  "providerId" | "modelId" | "reasoningEffort"
 >;
 
 export function isSupportedModel(model: unknown): model is SupportedModel {
diff --git a/types/src/shared/feature_flags.ts b/types/src/shared/feature_flags.ts
index c1e489248e0b..bf81d4d7890c 100644
--- a/types/src/shared/feature_flags.ts
+++ b/types/src/shared/feature_flags.ts
@@ -9,6 +9,8 @@ export const WHITELISTABLE_FEATURES = [
   "openai_o1_feature",
   "openai_o1_mini_feature",
   "openai_o1_high_reasoning_feature",
+  "openai_o1_custom_assistants_feature",
+  "openai_o1_high_reasoning_custom_assistants_feature",
   "index_private_slack_channel",
   "conversations_jit_actions",
   "labs_trackers",
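
---

Note for reviewers, not part of the patch: the availability gating added in
front/pages/api/w/[wId]/models.ts composes four checks (provider whitelist,
plan, model feature flag, custom-assistants feature flag). A minimal
standalone TypeScript sketch of that logic follows; ModelConfig,
filterAvailableModels, and the literal values are illustrative names for this
note, not identifiers from the codebase.

type ModelConfig = {
  modelId: string;
  providerId: string;
  largeModel: boolean;
  featureFlag?: string;
  customAssistantFeatureFlag?: string;
};

function filterAvailableModels(
  configs: readonly ModelConfig[],
  opts: {
    isProviderWhitelisted: (providerId: string) => boolean;
    isUpgraded: boolean;
    featureFlags: string[];
  }
): ModelConfig[] {
  return configs.filter((m) => {
    // 1. The provider must be whitelisted for the workspace.
    if (!opts.isProviderWhitelisted(m.providerId)) {
      return false;
    }
    // 2. Large models require an upgraded plan.
    if (m.largeModel && !opts.isUpgraded) {
      return false;
    }
    // 3. Flag-gated models need their feature flag enabled.
    if (m.featureFlag && !opts.featureFlags.includes(m.featureFlag)) {
      return false;
    }
    // 4. Exposure in the assistant builder additionally requires the
    //    custom-assistants flag, when the config declares one.
    if (
      m.customAssistantFeatureFlag &&
      !opts.featureFlags.includes(m.customAssistantFeatureFlag)
    ) {
      return false;
    }
    return true;
  });
}

// Example: an o1-style entry survives only when both of its flags are set.
const available = filterAvailableModels(
  [
    { modelId: "gpt-4o", providerId: "openai", largeModel: true },
    {
      modelId: "o1",
      providerId: "openai",
      largeModel: true,
      featureFlag: "openai_o1_feature",
      customAssistantFeatureFlag: "openai_o1_custom_assistants_feature",
    },
  ],
  {
    isProviderWhitelisted: () => true,
    isUpgraded: true,
    featureFlags: ["openai_o1_feature", "openai_o1_custom_assistants_feature"],
  }
);
// available contains both entries; drop either flag and the o1 entry is gone.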
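
Likewise, folding reasoningEffort into model identity (the getSupportedModelConfig
and SupportedModel changes above) is what lets O1_MODEL_CONFIG and
O1_HIGH_REASONING_MODEL_CONFIG share a modelId yet resolve to different
configs. A sketch under the same caveats (findModelConfig is a hypothetical
name):

type ModelKey = {
  providerId: string;
  modelId: string;
  reasoningEffort?: "low" | "medium" | "high";
};

function findModelConfig<T extends ModelKey>(
  configs: readonly T[],
  wanted: ModelKey
): T | undefined {
  return configs.find(
    (m) =>
      m.modelId === wanted.modelId &&
      m.providerId === wanted.providerId &&
      // Strict equality treats "unset" (undefined) as a value of its own,
      // so the default config never shadows the high-effort variant.
      m.reasoningEffort === wanted.reasoningEffort
  );
}

const configs = [
  { providerId: "openai", modelId: "o1" },
  { providerId: "openai", modelId: "o1", reasoningEffort: "high" as const },
];

findModelConfig(configs, { providerId: "openai", modelId: "o1" });
// -> the default config (reasoningEffort undefined on both sides)
findModelConfig(configs, {
  providerId: "openai",
  modelId: "o1",
  reasoningEffort: "high",
});
// -> the high-reasoning config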