From c19ca2af30a76294912ed2e0b002c74e5880fbea Mon Sep 17 00:00:00 2001 From: Flavien David Date: Fri, 8 Mar 2024 17:37:43 +0100 Subject: [PATCH] Support workspace with SSO enforced (#4227) * Support workspace with SSO enforced * :see_no_evil: * :see_no_evil: --- front/lib/auth.ts | 1 + front/lib/iam/session.ts | 49 +++++++++++++++++++++++++++-------- front/lib/models/workspace.ts | 5 ++++ types/src/front/user.ts | 1 + types/src/front/workspace.ts | 12 +++++++++ 5 files changed, 57 insertions(+), 11 deletions(-) diff --git a/front/lib/auth.ts b/front/lib/auth.ts index ff1fea9f9159..7aacb3fe2172 100644 --- a/front/lib/auth.ts +++ b/front/lib/auth.ts @@ -395,6 +395,7 @@ export class Authenticator { ACTIVATE_ALL_FEATURES_DEV && isDevelopment() ? [...WHITELISTABLE_FEATURES] : this._flags, + ssoEnforced: this._workspace.ssoEnforced, } : null; } diff --git a/front/lib/iam/session.ts b/front/lib/iam/session.ts index c1ed9c901122..223c4241eb2c 100644 --- a/front/lib/iam/session.ts +++ b/front/lib/iam/session.ts @@ -1,4 +1,5 @@ import type { RoleType, UserTypeWithWorkspaces } from "@dust-tt/types"; +import { isEnterpriseConnectionSub } from "@dust-tt/types"; import type { GetServerSidePropsContext, GetServerSidePropsResult, @@ -103,6 +104,22 @@ export type CustomGetServerSideProps< session: RequireUserPrivilege extends "none" ? null : SessionWithUser ) => Promise>; +export function statisfiesEnforceEntrepriseConnection( + auth: Authenticator, + session: SessionWithUser +) { + const owner = auth.workspace(); + if (!owner) { + return true; + } + + if (owner.ssoEnforced) { + return isEnterpriseConnectionSub(session.user.sub); + } + + return true; +} + async function getAuthenticator( context: GetServerSidePropsContext, session: SessionWithUser | null, @@ -156,17 +173,27 @@ export function makeGetServerSidePropsRequirementsWrapper< requireUserPrivilege ); - if ( - requireUserPrivilege !== "none" && - (!session || !isValidSession(session)) - ) { - return { - redirect: { - permanent: false, - // TODO(2024-03-04 flav) Add support for `returnTo=`. - destination: "/api/auth/login", - }, - }; + if (requireUserPrivilege !== "none") { + if (!session || !isValidSession(session)) { + return { + redirect: { + permanent: false, + // TODO(2024-03-04 flav) Add support for `returnTo=`. + destination: "/api/auth/login", + }, + }; + } + + // Validate the user's session to guarantee compliance with the workspace's SSO requirements when SSO is enforced. + if (auth && !statisfiesEnforceEntrepriseConnection(auth, session)) { + return { + redirect: { + permanent: false, + // TODO(2024-03-04 flav) Add support for `returnTo=`. + destination: `/sso-enforced?workspaceId=${auth.workspace()?.sId}`, + }, + }; + } } const userSession = session as RequireUserPrivilege extends "none" diff --git a/front/lib/models/workspace.ts b/front/lib/models/workspace.ts index 77747a28cffb..eefa3ff3822c 100644 --- a/front/lib/models/workspace.ts +++ b/front/lib/models/workspace.ts @@ -25,6 +25,7 @@ export class Workspace extends Model< declare name: string; declare description: string | null; declare segmentation: WorkspaceSegmentationType; + declare ssoEnforced?: boolean; declare subscriptions: NonAttribute; } Workspace.init( @@ -63,6 +64,10 @@ Workspace.init( type: DataTypes.STRING, allowNull: true, }, + ssoEnforced: { + type: DataTypes.BOOLEAN, + defaultValue: false, + }, }, { modelName: "workspace", diff --git a/types/src/front/user.ts b/types/src/front/user.ts index d3c7b740020d..8c79dc1337af 100644 --- a/types/src/front/user.ts +++ b/types/src/front/user.ts @@ -15,6 +15,7 @@ export type LightWorkspaceType = { export type WorkspaceType = LightWorkspaceType & { flags: WhitelistableFeature[]; + ssoEnforced?: boolean; }; export type UserProviderType = "github" | "google" | null; diff --git a/types/src/front/workspace.ts b/types/src/front/workspace.ts index b644b7b28130..894367284dc2 100644 --- a/types/src/front/workspace.ts +++ b/types/src/front/workspace.ts @@ -8,3 +8,15 @@ export interface WorkspaceEnterpriseConnection { } export type SupportedEnterpriseConnectionStrategies = "okta"; +export const supportedEnterpriseConnectionStrategies: SupportedEnterpriseConnectionStrategies[] = + ["okta"]; + +export function isEnterpriseConnectionSub( + sub: string +): sub is SupportedEnterpriseConnectionStrategies { + const [provider] = sub.split("|"); + + return supportedEnterpriseConnectionStrategies.includes( + provider as SupportedEnterpriseConnectionStrategies + ); +}