diff --git a/front/components/assistant/GalleryAssistantPreviewContainer.tsx b/front/components/assistant/GalleryAssistantPreviewContainer.tsx index 3438c35739e3..db48b960e7bd 100644 --- a/front/components/assistant/GalleryAssistantPreviewContainer.tsx +++ b/front/components/assistant/GalleryAssistantPreviewContainer.tsx @@ -10,7 +10,7 @@ import { useContext, useEffect, useState } from "react"; import type { NotificationType } from "@app/components/sparkle/Notification"; import { SendNotificationsContext } from "@app/components/sparkle/Notification"; import { isLargeModel } from "@app/lib/assistant"; -import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes"; +import { isUpgraded } from "@app/lib/plans/plan_codes"; import type { PostAgentListStatusRequestBody } from "@app/pages/api/w/[wId]/members/me/agent_list_status"; type AssistantPreviewFlow = "personal" | "workspace"; @@ -174,7 +174,7 @@ export function GalleryAssistantPreviewContainer({ const isGlobal = scope === "global"; const isAddedToWorkspace = flow === "workspace" && isAdded; - const hasAccessToLargeModels = plan?.code !== FREE_TEST_PLAN_CODE; + const hasAccessToLargeModels = isUpgraded(plan); const eligibleForTesting = hasAccessToLargeModels || !isLargeModel(generation?.model); const isTestable = !isGlobal && !isAdded && eligibleForTesting; diff --git a/front/components/assistant_builder/AssistantBuilder.tsx b/front/components/assistant_builder/AssistantBuilder.tsx index 0b8e45dbdbac..4ad536727bc8 100644 --- a/front/components/assistant_builder/AssistantBuilder.tsx +++ b/front/components/assistant_builder/AssistantBuilder.tsx @@ -68,7 +68,7 @@ import { SendNotificationsContext } from "@app/components/sparkle/Notification"; import { getSupportedModelConfig } from "@app/lib/assistant"; import { CONNECTOR_CONFIGURATIONS } from "@app/lib/connector_providers"; import { isActivatedStructuredDB } from "@app/lib/development"; -import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes"; +import { isUpgraded } from "@app/lib/plans/plan_codes"; import { useSlackChannelsLinkedWithAgent } from "@app/lib/swr"; import { classNames } from "@app/lib/utils"; @@ -308,10 +308,9 @@ export default function AssistantBuilder({ scope: defaultScope, generationSettings: { ...DEFAULT_ASSISTANT_STATE.generationSettings, - modelSettings: - plan.code === FREE_TEST_PLAN_CODE - ? GPT_3_5_TURBO_MODEL_CONFIG - : GPT_4_TURBO_MODEL_CONFIG, + modelSettings: !isUpgraded(plan) + ? GPT_3_5_TURBO_MODEL_CONFIG + : GPT_4_TURBO_MODEL_CONFIG, }, } ); @@ -1532,9 +1531,7 @@ function AdvancedSettings({ {usedModelConfigs - .filter( - (m) => !(m.largeModel && plan.code === FREE_TEST_PLAN_CODE) - ) + .filter((m) => !(m.largeModel && !isUpgraded(plan))) .map((modelConfig) => ( { const owner = auth.workspace(); - const plan = auth.plan(); - if (!owner || !plan) { + if (!owner) { throw new Error("Unexpected unauthenticated call to `runGeneration`"); } @@ -309,7 +307,7 @@ export async function* runGeneration( let model = c.model; - if (isLargeModel(model) && plan.code === FREE_TEST_PLAN_CODE) { + if (isLargeModel(model) && !auth.isUpgraded()) { yield { type: "generation_error", created: Date.now(), diff --git a/front/lib/api/assistant/global_agents.ts b/front/lib/api/assistant/global_agents.ts index 39bce3f88aa9..3ad37d144207 100644 --- a/front/lib/api/assistant/global_agents.ts +++ b/front/lib/api/assistant/global_agents.ts @@ -9,7 +9,7 @@ import type { ConnectorProvider, DataSourceType, } from "@dust-tt/types"; -import type { GlobalAgentStatus, PlanType } from "@dust-tt/types"; +import type { GlobalAgentStatus } from "@dust-tt/types"; import { GEMINI_PRO_DEFAULT_MODEL_CONFIG } from "@dust-tt/types"; import { CLAUDE_DEFAULT_MODEL_CONFIG, @@ -25,7 +25,6 @@ import { GLOBAL_AGENTS_SID } from "@app/lib/assistant"; import type { Authenticator } from "@app/lib/auth"; import { prodAPICredentialsForOwner } from "@app/lib/auth"; import { GlobalAgentSettings } from "@app/lib/models/assistant/agent"; -import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes"; import logger from "@app/logger/logger"; class HelperAssistantPrompt { @@ -84,20 +83,15 @@ async function _getHelperGlobalAgent( if (!owner) { throw new Error("Unexpected `auth` without `workspace`."); } - const plan = auth.plan(); - if (!plan) { - throw new Error("Unexpected `auth` without `plan`."); - } - const model = - plan.code === FREE_TEST_PLAN_CODE - ? { - providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId, - modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId, - } - : { - providerId: GPT_4_TURBO_MODEL_CONFIG.providerId, - modelId: GPT_4_TURBO_MODEL_CONFIG.modelId, - }; + const model = !auth.isUpgraded() + ? { + providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId, + modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId, + } + : { + providerId: GPT_4_TURBO_MODEL_CONFIG.providerId, + modelId: GPT_4_TURBO_MODEL_CONFIG.modelId, + }; return { id: -1, sId: GLOBAL_AGENTS_SID.HELPER, @@ -153,12 +147,11 @@ async function _getGPT35TurboGlobalAgent({ } async function _getGPT4GlobalAgent({ - plan, + auth, }: { - plan: PlanType; + auth: Authenticator; }): Promise { - const status = - plan.code === FREE_TEST_PLAN_CODE ? "disabled_free_workspace" : "active"; + const status = !auth.isUpgraded() ? "disabled_free_workspace" : "active"; return { id: -1, sId: GLOBAL_AGENTS_SID.GPT4, @@ -218,14 +211,13 @@ async function _getClaudeInstantGlobalAgent({ } async function _getClaudeGlobalAgent({ + auth, settings, - plan, }: { + auth: Authenticator; settings: GlobalAgentSettings | null; - plan: PlanType; }): Promise { - const status = - plan.code === FREE_TEST_PLAN_CODE ? "disabled_free_workspace" : "active"; + const status = !auth.isUpgraded() ? "disabled_free_workspace" : "active"; return { id: -1, sId: GLOBAL_AGENTS_SID.CLAUDE, @@ -252,14 +244,14 @@ async function _getClaudeGlobalAgent({ } async function _getMistralMediumGlobalAgent({ - plan, + auth, settings, }: { - plan: PlanType; + auth: Authenticator; settings: GlobalAgentSettings | null; }): Promise { let status = settings?.status ?? "disabled_by_admin"; - if (plan.code === FREE_TEST_PLAN_CODE) { + if (!auth.isUpgraded()) { status = "disabled_free_workspace"; } @@ -378,11 +370,6 @@ async function _getManagedDataSourceAgent( throw new Error("Unexpected `auth` without `workspace`."); } - const plan = auth.plan(); - if (!plan) { - throw new Error("Unexpected `auth` without `plan`."); - } - const prodCredentials = await prodAPICredentialsForOwner(owner); // Check if deactivated by an admin @@ -441,16 +428,15 @@ async function _getManagedDataSourceAgent( generation: { id: -1, prompt, - model: - plan.code === FREE_TEST_PLAN_CODE - ? { - providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId, - modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId, - } - : { - providerId: GPT_4_TURBO_MODEL_CONFIG.providerId, - modelId: GPT_4_TURBO_MODEL_CONFIG.modelId, - }, + model: !auth.isUpgraded() + ? { + providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId, + modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId, + } + : { + providerId: GPT_4_TURBO_MODEL_CONFIG.providerId, + modelId: GPT_4_TURBO_MODEL_CONFIG.modelId, + }, temperature: 0.4, }, action: { @@ -567,10 +553,8 @@ async function _getNotionGlobalAgent( async function _getDustGlobalAgent( auth: Authenticator, { - plan, settings, }: { - plan: PlanType; settings: GlobalAgentSettings | null; } ): Promise { @@ -647,16 +631,15 @@ async function _getDustGlobalAgent( id: -1, prompt: "Assist the user based on the retrieved data from their workspace. Unlesss the user explicitely asks for a detailed answer, you goal is to provide a quick answer to their question.", - model: - plan.code === FREE_TEST_PLAN_CODE - ? { - providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId, - modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId, - } - : { - providerId: GPT_4_TURBO_MODEL_CONFIG.providerId, - modelId: GPT_4_TURBO_MODEL_CONFIG.modelId, - }, + model: !auth.isUpgraded() + ? { + providerId: GPT_3_5_TURBO_MODEL_CONFIG.providerId, + modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId, + } + : { + providerId: GPT_4_TURBO_MODEL_CONFIG.providerId, + modelId: GPT_4_TURBO_MODEL_CONFIG.modelId, + }, temperature: 0.4, }, action: { @@ -693,11 +676,6 @@ export async function getGlobalAgent( throw new Error("Cannot find Global Agent Configuration: no workspace."); } - const plan = auth.plan(); - if (!plan) { - throw new Error("Unexpected `auth` without `plan`."); - } - if (preFetchedDataSources === null) { const prodCredentials = await prodAPICredentialsForOwner(owner); const api = new DustAPI(prodCredentials, logger); @@ -721,18 +699,18 @@ export async function getGlobalAgent( agentConfiguration = await _getGPT35TurboGlobalAgent({ settings }); break; case GLOBAL_AGENTS_SID.GPT4: - agentConfiguration = await _getGPT4GlobalAgent({ plan }); + agentConfiguration = await _getGPT4GlobalAgent({ auth }); break; case GLOBAL_AGENTS_SID.CLAUDE_INSTANT: agentConfiguration = await _getClaudeInstantGlobalAgent({ settings }); break; case GLOBAL_AGENTS_SID.CLAUDE: - agentConfiguration = await _getClaudeGlobalAgent({ settings, plan }); + agentConfiguration = await _getClaudeGlobalAgent({ auth, settings }); break; case GLOBAL_AGENTS_SID.MISTRAL_MEDIUM: agentConfiguration = await _getMistralMediumGlobalAgent({ - plan, settings, + auth, }); break; case GLOBAL_AGENTS_SID.MISTRAL_SMALL: @@ -766,7 +744,7 @@ export async function getGlobalAgent( }); break; case GLOBAL_AGENTS_SID.DUST: - agentConfiguration = await _getDustGlobalAgent(auth, { plan, settings }); + agentConfiguration = await _getDustGlobalAgent(auth, { settings }); break; default: return null; diff --git a/front/lib/api/assistant/pubsub.ts b/front/lib/api/assistant/pubsub.ts index bce045a601ac..dc3c5fbd32cd 100644 --- a/front/lib/api/assistant/pubsub.ts +++ b/front/lib/api/assistant/pubsub.ts @@ -56,6 +56,9 @@ export async function postUserMessageWithPubSub( let rateLimitKey: string | undefined = ""; if (auth.user()?.id) { maxPerTimeframe = 50; + if (auth.isUpgraded()) { + maxPerTimeframe = 200; + } timeframeSeconds = 60 * 60 * 3; rateLimitKey = `postUserMessageUser:${auth.user()?.id}`; } else { diff --git a/front/lib/auth.ts b/front/lib/auth.ts index 661263a00356..6b28ab9016bb 100644 --- a/front/lib/auth.ts +++ b/front/lib/auth.ts @@ -22,6 +22,7 @@ import { } from "@app/lib/models"; import type { PlanAttributes } from "@app/lib/plans/free_plans"; import { FREE_TEST_PLAN_DATA } from "@app/lib/plans/free_plans"; +import { isUpgraded } from "@app/lib/plans/plan_codes"; import { new_id } from "@app/lib/utils"; import logger from "@app/logger/logger"; import { authOptions } from "@app/pages/api/auth/[...nextauth]"; @@ -328,6 +329,10 @@ export class Authenticator { return this._subscription ? this._subscription.plan : null; } + isUpgraded(): boolean { + return isUpgraded(this.plan()); + } + /** * This is a convenience method to get the user from the Authenticator. The returned UserType * object won't have the user's workspaces set. diff --git a/front/lib/plans/plan_codes.ts b/front/lib/plans/plan_codes.ts index 6f62d0b1e043..4597322a38a0 100644 --- a/front/lib/plans/plan_codes.ts +++ b/front/lib/plans/plan_codes.ts @@ -1,3 +1,5 @@ +import { PlanType } from "@dust-tt/types"; + // Current free plans: export const FREE_UPGRADED_PLAN_CODE = "FREE_UPGRADED_PLAN"; export const FREE_TEST_PLAN_CODE = "FREE_TEST_PLAN"; @@ -9,3 +11,15 @@ export const PRO_PLAN_SEAT_29_CODE = "PRO_PLAN_SEAT_29"; * ENT_PLAN_FAKE is not subscribable and is only used to display the Enterprise plan in the UI (hence it's not stored on the db). */ export const ENT_PLAN_FAKE_CODE = "ENT_PLAN_FAKE_CODE"; + +/** + * `isUpgraded` returns true if the plan has access to all features of Dust, including large + * language models (meaning it's either a paid plan or free plan with (eg friends and family, or + * free trial plan)). + * + * Note: We didn't go for isFree or isPayingWorkspace as we have "upgraded" plans that are free. + */ +export const isUpgraded = (plan: PlanType | null): boolean => { + if (!plan) return false; + return plan.code !== FREE_TEST_PLAN_CODE; +}; diff --git a/front/pages/api/w/[wId]/data_sources/index.ts b/front/pages/api/w/[wId]/data_sources/index.ts index a165c8fc4201..73c91288e3fd 100644 --- a/front/pages/api/w/[wId]/data_sources/index.ts +++ b/front/pages/api/w/[wId]/data_sources/index.ts @@ -7,7 +7,6 @@ import type { NextApiRequest, NextApiResponse } from "next"; import { getDataSources } from "@app/lib/api/data_sources"; import { Authenticator, getSession } from "@app/lib/auth"; import { DataSource } from "@app/lib/models"; -import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes"; import logger from "@app/logger/logger"; import { apiError, withLogging } from "@app/logger/withlogging"; @@ -140,7 +139,7 @@ async function handler( splitter_id: "base_v0", max_chunk_size: dataSourceMaxChunkSize, qdrant_config: - plan.code !== FREE_TEST_PLAN_CODE && NODE_ENV === "production" + auth.isUpgraded() && NODE_ENV === "production" ? { cluster: "dedicated-1", shadow_write_cluster: null, diff --git a/front/pages/poke/[wId]/index.tsx b/front/pages/poke/[wId]/index.tsx index a15464735a5c..fcce25215ffc 100644 --- a/front/pages/poke/[wId]/index.tsx +++ b/front/pages/poke/[wId]/index.tsx @@ -38,6 +38,7 @@ import { useSubmitFunction } from "@app/lib/client/utils"; import { FREE_TEST_PLAN_CODE, FREE_UPGRADED_PLAN_CODE, + isUpgraded, } from "@app/lib/plans/plan_codes"; import { getPlanInvitation } from "@app/lib/plans/subscription"; import { usePokePlans } from "@app/lib/swr"; @@ -657,7 +658,7 @@ const WorkspacePage = ({ variant="secondaryWarning" onClick={onDowngrade} disabled={ - subscription.plan.code === FREE_TEST_PLAN_CODE || + !isUpgraded(subscription.plan) || workspaceHasManagedDataSources } /> @@ -667,9 +668,7 @@ const WorkspacePage = ({ label="Upgrade to free upgraded plan" variant="tertiary" onClick={onUpgrade} - disabled={ - subscription.plan.code !== FREE_TEST_PLAN_CODE - } + disabled={isUpgraded(subscription.plan)} /> diff --git a/front/pages/w/[wId]/members/index.tsx b/front/pages/w/[wId]/members/index.tsx index 12abf983193f..25e4609124ba 100644 --- a/front/pages/w/[wId]/members/index.tsx +++ b/front/pages/w/[wId]/members/index.tsx @@ -28,7 +28,7 @@ import AppLayout from "@app/components/sparkle/AppLayout"; import { subNavigationAdmin } from "@app/components/sparkle/navigation"; import { SendNotificationsContext } from "@app/components/sparkle/Notification"; import { Authenticator, getSession, getUserFromSession } from "@app/lib/auth"; -import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes"; +import { isUpgraded } from "@app/lib/plans/plan_codes"; import { useMembers, useWorkspaceInvitations } from "@app/lib/swr"; import { classNames, isEmailValid } from "@app/lib/utils"; @@ -153,8 +153,7 @@ export default function WorkspaceAdmin({ size="sm" icon={Cog6ToothIcon} onClick={() => { - if (plan.code === FREE_TEST_PLAN_CODE) - setShowNoInviteLinkPopup(true); + if (!isUpgraded(plan)) setShowNoInviteLinkPopup(true); else setInviteSettingsModalOpen(true); }} /> @@ -278,8 +277,7 @@ export default function WorkspaceAdmin({ size="sm" icon={PlusIcon} onClick={() => { - if (plan.code === FREE_TEST_PLAN_CODE) - setShowNoInviteFreePlanPopup(true); + if (!isUpgraded(plan)) setShowNoInviteFreePlanPopup(true); else if (subscription.paymentFailingSince) setShowNoInviteFailedPaymentPopup(true); else setInviteEmailModalOpen(true); diff --git a/front/pages/w/[wId]/subscription/index.tsx b/front/pages/w/[wId]/subscription/index.tsx index 2f14a372579c..093fee5ab2ad 100644 --- a/front/pages/w/[wId]/subscription/index.tsx +++ b/front/pages/w/[wId]/subscription/index.tsx @@ -23,6 +23,7 @@ import { useSubmitFunction } from "@app/lib/client/utils"; import { FREE_TEST_PLAN_CODE, FREE_UPGRADED_PLAN_CODE, + isUpgraded, PRO_PLAN_SEAT_29_CODE, } from "@app/lib/plans/plan_codes"; import { getPlanInvitation } from "@app/lib/plans/subscription"; @@ -131,7 +132,7 @@ export default function Subscription({ if (content.checkoutUrl) { await router.push(content.checkoutUrl); } else if (content.success) { - await router.reload(); // We cannot swr the plan so we just reload the page. + router.reload(); // We cannot swr the plan so we just reload the page. } } }); @@ -166,7 +167,7 @@ export default function Subscription({ const isProcessing = isSubscribingPlan || isGoingToStripePortal; const plan = subscription.plan; - const chipColor = plan.code === FREE_TEST_PLAN_CODE ? "emerald" : "sky"; + const chipColor = !isUpgraded(plan) ? "emerald" : "sky"; const onClickProPlan = async () => handleSubscribePlan(); const onClickEnterprisePlan = () => {