Skip to content

Commit

Permalink
oauth: Notion connector dual-flow + front gating (#6372)
Browse files Browse the repository at this point in the history
* Notion connectors dual-flow oauth/nango

* oauth: gate notion

* fix
  • Loading branch information
spolu authored Jul 22, 2024
1 parent b0c3c70 commit 0a5782f
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 38 deletions.
91 changes: 67 additions & 24 deletions connectors/src/connectors/notion/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ import type {
ContentNodesViewType,
Result,
} from "@dust-tt/types";
import { Err, getNotionDatabaseTableId, Ok } from "@dust-tt/types";
import {
Err,
getNotionDatabaseTableId,
getOAuthConnectionAccessToken,
Ok,
} from "@dust-tt/types";
import { v4 as uuidv4 } from "uuid";

import { notionConfig } from "@connectors/connectors/notion/lib/config";
Expand All @@ -13,6 +18,7 @@ import {
launchNotionSyncWorkflow,
stopNotionSyncWorkflow,
} from "@connectors/connectors/notion/temporal/client";
import { apiConfig } from "@connectors/lib/api/config";
import { dataSourceConfigFromConnector } from "@connectors/lib/api/data_source_config";
import {
NotionConnectorState,
Expand All @@ -23,6 +29,7 @@ import {
getAccessTokenFromNango,
getConnectionFromNango,
} from "@connectors/lib/nango_helpers";
import { isDualUseOAuthConnectionId } from "@connectors/lib/oauth";
import mainLogger from "@connectors/logger/logger";
import { ConnectorResource } from "@connectors/resources/connector_resource";
import type { DataSourceConfig } from "@connectors/types/data_source_config";
Expand All @@ -35,6 +42,31 @@ const { getRequiredNangoNotionConnectorId } = notionConfig;

const logger = mainLogger.child({ provider: "notion" });

async function workspaceIdFromConnectionId(connectionId: string) {
if (isDualUseOAuthConnectionId(connectionId)) {
const tokRes = await getOAuthConnectionAccessToken({
config: apiConfig.getOAuthAPIConfig(),
logger,
provider: "notion",
connectionId,
});
if (tokRes.isErr()) {
return new Err("Error retrieving access token");
}
return new Ok(
(tokRes.value.scrubbed_raw_json as { workspace_id?: string })
.workspace_id ?? null
);
} else {
const connectionRes = await getConnectionFromNango({
connectionId: connectionId,
integrationId: getRequiredNangoNotionConnectorId(),
refreshToken: false,
});
return new Ok(connectionRes?.credentials?.raw?.workspace_id || null);
}
}

export class NotionConnectorManager extends BaseConnectorManager<null> {
static async create({
dataSourceConfig,
Expand All @@ -43,13 +75,28 @@ export class NotionConnectorManager extends BaseConnectorManager<null> {
dataSourceConfig: DataSourceConfig;
connectionId: NangoConnectionId;
}): Promise<Result<string, Error>> {
const nangoConnectionId = connectionId;
let notionAccessToken: string | null = null;

const notionAccessToken = await getAccessTokenFromNango({
connectionId: nangoConnectionId,
integrationId: getRequiredNangoNotionConnectorId(),
useCache: false,
});
if (isDualUseOAuthConnectionId(connectionId)) {
const tokRes = await getOAuthConnectionAccessToken({
config: apiConfig.getOAuthAPIConfig(),
logger,
provider: "notion",
connectionId,
});
if (tokRes.isErr()) {
return new Err(
new Error("Error retrieving access token: " + tokRes.error.message)
);
}
notionAccessToken = tokRes.value.access_token;
} else {
notionAccessToken = (await getAccessTokenFromNango({
connectionId: connectionId,
integrationId: getRequiredNangoNotionConnectorId(),
useCache: false,
})) as string;
}

const isValidToken = await validateAccessToken(notionAccessToken);
if (!isValidToken) {
Expand All @@ -61,7 +108,7 @@ export class NotionConnectorManager extends BaseConnectorManager<null> {
connector = await ConnectorResource.makeNew(
"notion",
{
connectionId: nangoConnectionId,
connectionId,
workspaceAPIKey: dataSourceConfig.workspaceAPIKey,
workspaceId: dataSourceConfig.workspaceId,
dataSourceName: dataSourceConfig.dataSourceName,
Expand Down Expand Up @@ -107,29 +154,25 @@ export class NotionConnectorManager extends BaseConnectorManager<null> {

if (connectionId) {
const oldConnectionId = c.connectionId;
const connectionRes = await getConnectionFromNango({
connectionId: oldConnectionId,
integrationId: getRequiredNangoNotionConnectorId(),
refreshToken: false,
});

const newConnectionRes = await getConnectionFromNango({
connectionId,
integrationId: getRequiredNangoNotionConnectorId(),
refreshToken: false,
});
const [workspaceIdRes, newWorkspaceIdRes] = await Promise.all([
workspaceIdFromConnectionId(oldConnectionId),
workspaceIdFromConnectionId(connectionId),
]);

const workspaceId = connectionRes?.credentials?.raw?.workspace_id || null;
const newWorkspaceId =
newConnectionRes?.credentials?.raw?.workspace_id || null;
if (workspaceIdRes.isErr() || newWorkspaceIdRes.isErr()) {
return new Err({
type: "connector_update_error",
message: "Error retrieving old workspace id",
});
}

if (!workspaceId || !newWorkspaceId) {
if (!workspaceIdRes.value || !newWorkspaceIdRes.value) {
return new Err({
type: "connector_update_error",
message: "Error retrieving nango connection info to update connector",
});
}
if (workspaceId !== newWorkspaceId) {
if (workspaceIdRes.value !== newWorkspaceIdRes.value) {
return new Err({
type: "connector_oauth_target_mismatch",
message: "Cannot change workspace of a Notion connector",
Expand Down
2 changes: 1 addition & 1 deletion connectors/src/connectors/notion/lib/notion_api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ export async function validateAccessToken(notionAccessToken: string) {
logger: notionClientLogger,
});
try {
await notionClient.search({ page_size: 1 });
await notionClient.users.me({});
} catch (e) {
return false;
}
Expand Down
40 changes: 31 additions & 9 deletions connectors/src/connectors/notion/temporal/activities.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ import type {
NotionGarbageCollectionMode,
} from "@dust-tt/types";
import type { PageObjectProperties, ParsedNotionBlock } from "@dust-tt/types";
import { assertNever, getNotionDatabaseTableId, slugify } from "@dust-tt/types";
import {
assertNever,
getNotionDatabaseTableId,
getOAuthConnectionAccessToken,
slugify,
} from "@dust-tt/types";
import { isFullBlock, isFullPage, isNotionClientError } from "@notionhq/client";
import type { PageObjectResponse } from "@notionhq/client/build/src/api-endpoints";
import { Context } from "@temporalio/activity";
Expand Down Expand Up @@ -41,6 +46,7 @@ import {
updateAllParentsFields,
} from "@connectors/connectors/notion/lib/parents";
import { getTagsForPage } from "@connectors/connectors/notion/lib/tags";
import { apiConfig } from "@connectors/lib/api/config";
import { dataSourceConfigFromConnector } from "@connectors/lib/api/data_source_config";
import { concurrentExecutor } from "@connectors/lib/async_utils";
import {
Expand All @@ -65,6 +71,7 @@ import {
NotionPage,
} from "@connectors/lib/models/notion";
import { getAccessTokenFromNango } from "@connectors/lib/nango_helpers";
import { isDualUseOAuthConnectionId } from "@connectors/lib/oauth";
import { redisClient } from "@connectors/lib/redis";
import { syncStarted, syncSucceeded } from "@connectors/lib/sync_status";
import { heartbeat } from "@connectors/lib/temporal";
Expand Down Expand Up @@ -549,15 +556,30 @@ export async function saveStartSync(connectorId: ModelId) {
}

export async function getNotionAccessToken(
nangoConnectionId: string
connectionId: string
): Promise<string> {
const notionAccessToken = await getAccessTokenFromNango({
connectionId: nangoConnectionId,
integrationId: getRequiredNangoNotionConnectorId(),
useCache: true,
});

return notionAccessToken;
if (isDualUseOAuthConnectionId(connectionId)) {
const tokRes = await getOAuthConnectionAccessToken({
config: apiConfig.getOAuthAPIConfig(),
logger,
provider: "notion",
connectionId,
});
if (tokRes.isErr()) {
logger.error(
{ connectionId, error: tokRes.error },
"Error retrieving Notion access token"
);
throw new Error("Error retrieving Notion access token");
}
return tokRes.value.access_token;
} else {
return getAccessTokenFromNango({
connectionId: connectionId,
integrationId: getRequiredNangoNotionConnectorId(),
useCache: true,
});
}
}

export async function shouldGarbageCollect({
Expand Down
6 changes: 2 additions & 4 deletions front/pages/w/[wId]/builder/data-sources/managed.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,9 @@ export async function setupConnection({
isOAuthProvider(provider) &&
// `oauth`-ready providers
(["github", "slack"].includes(provider) ||
(["intercom"].includes(provider) &&
// Behind flag oauth-ready providers
(["intercom", "notion"].includes(provider) &&
owner.flags.includes("test_oauth_setup")))
// Behind flag oauth-ready providers
// ([""].includes(provider) &&
// owner.flags.includes("test_oauth_setup"))
) {
// OAuth flow
const cRes = await setupOAuthConnection({
Expand Down

0 comments on commit 0a5782f

Please sign in to comment.