diff --git a/connectors/src/connectors/notion/index.ts b/connectors/src/connectors/notion/index.ts index 62102bbf3bddd..6f6f5cea10393 100644 --- a/connectors/src/connectors/notion/index.ts +++ b/connectors/src/connectors/notion/index.ts @@ -4,7 +4,12 @@ import type { ContentNodesViewType, Result, } from "@dust-tt/types"; -import { Err, getNotionDatabaseTableId, Ok } from "@dust-tt/types"; +import { + Err, + getNotionDatabaseTableId, + getOAuthConnectionAccessToken, + Ok, +} from "@dust-tt/types"; import { v4 as uuidv4 } from "uuid"; import { notionConfig } from "@connectors/connectors/notion/lib/config"; @@ -13,6 +18,7 @@ import { launchNotionSyncWorkflow, stopNotionSyncWorkflow, } from "@connectors/connectors/notion/temporal/client"; +import { apiConfig } from "@connectors/lib/api/config"; import { dataSourceConfigFromConnector } from "@connectors/lib/api/data_source_config"; import { NotionConnectorState, @@ -23,6 +29,7 @@ import { getAccessTokenFromNango, getConnectionFromNango, } from "@connectors/lib/nango_helpers"; +import { isDualUseOAuthConnectionId } from "@connectors/lib/oauth"; import mainLogger from "@connectors/logger/logger"; import { ConnectorResource } from "@connectors/resources/connector_resource"; import type { DataSourceConfig } from "@connectors/types/data_source_config"; @@ -35,6 +42,31 @@ const { getRequiredNangoNotionConnectorId } = notionConfig; const logger = mainLogger.child({ provider: "notion" }); +async function workspaceIdFromConnectionId(connectionId: string) { + if (isDualUseOAuthConnectionId(connectionId)) { + const tokRes = await getOAuthConnectionAccessToken({ + config: apiConfig.getOAuthAPIConfig(), + logger, + provider: "notion", + connectionId, + }); + if (tokRes.isErr()) { + return new Err("Error retrieving access token"); + } + return new Ok( + (tokRes.value.scrubbed_raw_json as { workspace_id?: string }) + .workspace_id ?? null + ); + } else { + const connectionRes = await getConnectionFromNango({ + connectionId: connectionId, + integrationId: getRequiredNangoNotionConnectorId(), + refreshToken: false, + }); + return new Ok(connectionRes?.credentials?.raw?.workspace_id || null); + } +} + export class NotionConnectorManager extends BaseConnectorManager { static async create({ dataSourceConfig, @@ -43,13 +75,28 @@ export class NotionConnectorManager extends BaseConnectorManager { dataSourceConfig: DataSourceConfig; connectionId: NangoConnectionId; }): Promise> { - const nangoConnectionId = connectionId; + let notionAccessToken: string | null = null; - const notionAccessToken = await getAccessTokenFromNango({ - connectionId: nangoConnectionId, - integrationId: getRequiredNangoNotionConnectorId(), - useCache: false, - }); + if (isDualUseOAuthConnectionId(connectionId)) { + const tokRes = await getOAuthConnectionAccessToken({ + config: apiConfig.getOAuthAPIConfig(), + logger, + provider: "notion", + connectionId, + }); + if (tokRes.isErr()) { + return new Err( + new Error("Error retrieving access token: " + tokRes.error.message) + ); + } + notionAccessToken = tokRes.value.access_token; + } else { + notionAccessToken = (await getAccessTokenFromNango({ + connectionId: connectionId, + integrationId: getRequiredNangoNotionConnectorId(), + useCache: false, + })) as string; + } const isValidToken = await validateAccessToken(notionAccessToken); if (!isValidToken) { @@ -61,7 +108,7 @@ export class NotionConnectorManager extends BaseConnectorManager { connector = await ConnectorResource.makeNew( "notion", { - connectionId: nangoConnectionId, + connectionId, workspaceAPIKey: dataSourceConfig.workspaceAPIKey, workspaceId: dataSourceConfig.workspaceId, dataSourceName: dataSourceConfig.dataSourceName, @@ -107,29 +154,25 @@ export class NotionConnectorManager extends BaseConnectorManager { if (connectionId) { const oldConnectionId = c.connectionId; - const connectionRes = await getConnectionFromNango({ - connectionId: oldConnectionId, - integrationId: getRequiredNangoNotionConnectorId(), - refreshToken: false, - }); - - const newConnectionRes = await getConnectionFromNango({ - connectionId, - integrationId: getRequiredNangoNotionConnectorId(), - refreshToken: false, - }); + const [workspaceIdRes, newWorkspaceIdRes] = await Promise.all([ + workspaceIdFromConnectionId(oldConnectionId), + workspaceIdFromConnectionId(connectionId), + ]); - const workspaceId = connectionRes?.credentials?.raw?.workspace_id || null; - const newWorkspaceId = - newConnectionRes?.credentials?.raw?.workspace_id || null; + if (workspaceIdRes.isErr() || newWorkspaceIdRes.isErr()) { + return new Err({ + type: "connector_update_error", + message: "Error retrieving old workspace id", + }); + } - if (!workspaceId || !newWorkspaceId) { + if (!workspaceIdRes.value || !newWorkspaceIdRes.value) { return new Err({ type: "connector_update_error", message: "Error retrieving nango connection info to update connector", }); } - if (workspaceId !== newWorkspaceId) { + if (workspaceIdRes.value !== newWorkspaceIdRes.value) { return new Err({ type: "connector_oauth_target_mismatch", message: "Cannot change workspace of a Notion connector", diff --git a/connectors/src/connectors/notion/lib/notion_api.ts b/connectors/src/connectors/notion/lib/notion_api.ts index 0a1e62d1406cf..a6a7f57f814fc 100644 --- a/connectors/src/connectors/notion/lib/notion_api.ts +++ b/connectors/src/connectors/notion/lib/notion_api.ts @@ -753,7 +753,7 @@ export async function validateAccessToken(notionAccessToken: string) { logger: notionClientLogger, }); try { - await notionClient.search({ page_size: 1 }); + await notionClient.users.me({}); } catch (e) { return false; } diff --git a/connectors/src/connectors/notion/temporal/activities.ts b/connectors/src/connectors/notion/temporal/activities.ts index 2f84c3fe5c4eb..b232baa61e3aa 100644 --- a/connectors/src/connectors/notion/temporal/activities.ts +++ b/connectors/src/connectors/notion/temporal/activities.ts @@ -4,7 +4,12 @@ import type { NotionGarbageCollectionMode, } from "@dust-tt/types"; import type { PageObjectProperties, ParsedNotionBlock } from "@dust-tt/types"; -import { assertNever, getNotionDatabaseTableId, slugify } from "@dust-tt/types"; +import { + assertNever, + getNotionDatabaseTableId, + getOAuthConnectionAccessToken, + slugify, +} from "@dust-tt/types"; import { isFullBlock, isFullPage, isNotionClientError } from "@notionhq/client"; import type { PageObjectResponse } from "@notionhq/client/build/src/api-endpoints"; import { Context } from "@temporalio/activity"; @@ -41,6 +46,7 @@ import { updateAllParentsFields, } from "@connectors/connectors/notion/lib/parents"; import { getTagsForPage } from "@connectors/connectors/notion/lib/tags"; +import { apiConfig } from "@connectors/lib/api/config"; import { dataSourceConfigFromConnector } from "@connectors/lib/api/data_source_config"; import { concurrentExecutor } from "@connectors/lib/async_utils"; import { @@ -65,6 +71,7 @@ import { NotionPage, } from "@connectors/lib/models/notion"; import { getAccessTokenFromNango } from "@connectors/lib/nango_helpers"; +import { isDualUseOAuthConnectionId } from "@connectors/lib/oauth"; import { redisClient } from "@connectors/lib/redis"; import { syncStarted, syncSucceeded } from "@connectors/lib/sync_status"; import { heartbeat } from "@connectors/lib/temporal"; @@ -549,15 +556,30 @@ export async function saveStartSync(connectorId: ModelId) { } export async function getNotionAccessToken( - nangoConnectionId: string + connectionId: string ): Promise { - const notionAccessToken = await getAccessTokenFromNango({ - connectionId: nangoConnectionId, - integrationId: getRequiredNangoNotionConnectorId(), - useCache: true, - }); - - return notionAccessToken; + if (isDualUseOAuthConnectionId(connectionId)) { + const tokRes = await getOAuthConnectionAccessToken({ + config: apiConfig.getOAuthAPIConfig(), + logger, + provider: "notion", + connectionId, + }); + if (tokRes.isErr()) { + logger.error( + { connectionId, error: tokRes.error }, + "Error retrieving Notion access token" + ); + throw new Error("Error retrieving Notion access token"); + } + return tokRes.value.access_token; + } else { + return getAccessTokenFromNango({ + connectionId: connectionId, + integrationId: getRequiredNangoNotionConnectorId(), + useCache: true, + }); + } } export async function shouldGarbageCollect({ diff --git a/front/pages/w/[wId]/builder/data-sources/managed.tsx b/front/pages/w/[wId]/builder/data-sources/managed.tsx index ec98a90708d49..8a5c551d8765a 100644 --- a/front/pages/w/[wId]/builder/data-sources/managed.tsx +++ b/front/pages/w/[wId]/builder/data-sources/managed.tsx @@ -110,11 +110,9 @@ export async function setupConnection({ isOAuthProvider(provider) && // `oauth`-ready providers (["github", "slack"].includes(provider) || - (["intercom"].includes(provider) && + // Behind flag oauth-ready providers + (["intercom", "notion"].includes(provider) && owner.flags.includes("test_oauth_setup"))) - // Behind flag oauth-ready providers - // ([""].includes(provider) && - // owner.flags.includes("test_oauth_setup")) ) { // OAuth flow const cRes = await setupOAuthConnection({