diff --git a/front/lib/labs/transcripts/utils/helpers.ts b/front/lib/labs/transcripts/utils/helpers.ts index 66a3f4f7bc35..d7d87a0a3b10 100644 --- a/front/lib/labs/transcripts/utils/helpers.ts +++ b/front/lib/labs/transcripts/utils/helpers.ts @@ -1,14 +1,14 @@ import type { ModelId, - NangoConnectionId, NangoIntegrationId, + OAuthProvider, Result, } from "@dust-tt/types"; -import { Err, Ok } from "@dust-tt/types"; +import { Err, getOAuthConnectionAccessToken, Ok } from "@dust-tt/types"; import { Nango } from "@nangohq/node"; import { google } from "googleapis"; -import type { OAuth2Client } from "googleapis-common"; +import apiConfig from "@app/lib/api/config"; import type { Authenticator } from "@app/lib/auth"; import config from "@app/lib/labs/config"; import { LabsTranscriptsConfigurationResource } from "@app/lib/resources/labs_transcripts_resource"; @@ -16,24 +16,12 @@ import logger from "@app/logger/logger"; const nango = new Nango({ secretKey: config.getNangoSecretKey() }); -// Google Auth -export async function getGoogleAuthObject( - nangoIntegrationId: NangoIntegrationId, - nangoConnectionId: NangoConnectionId -): Promise { - const res = await nango.getConnection(nangoIntegrationId, nangoConnectionId); - - const oauth2Client = new google.auth.OAuth2(); - oauth2Client.setCredentials({ - access_token: res.credentials.raw.access_token, - scope: res.credentials.raw.scope, - token_type: res.credentials.raw.token_type, - expiry_date: new Date(res.credentials.raw.expires_at).getTime(), - }); - - return oauth2Client; +export function isDualUseOAuthConnectionId(connectionId: string): boolean { + // TODO(spolu): make sure this function is removed once fully migrated. + return connectionId.startsWith("con_"); } +// Google Auth export async function getTranscriptsGoogleAuth( auth: Authenticator, userId: ModelId @@ -52,15 +40,54 @@ export async function getTranscriptsGoogleAuth( return; } - return getGoogleAuthObject( - config.getNangoConnectorIdForProvider("google_drive"), - transcriptsConfiguration.connectionId - ); + const connectionId = transcriptsConfiguration.connectionId; + const provider: OAuthProvider = "google_drive"; + + const oauth2Client = new google.auth.OAuth2(); + + if (isDualUseOAuthConnectionId(connectionId)) { + const tokRes = await getOAuthConnectionAccessToken({ + config: apiConfig.getOAuthAPIConfig(), + logger, + provider, + connectionId, + }); + + if (tokRes.isErr()) { + logger.error( + { connectionId, error: tokRes.error, provider }, + "Error retrieving access token" + ); + throw new Error(`Error retrieving access token from ${provider}`); + } + + oauth2Client.setCredentials({ + access_token: tokRes.value.access_token, + scope: (tokRes.value.scrubbed_raw_json as { scope: string }).scope, + token_type: (tokRes.value.scrubbed_raw_json as { token_type: string }) + .token_type, + expiry_date: tokRes.value.access_token_expiry, + }); + } else { + const res = await nango.getConnection( + config.getNangoConnectorIdForProvider("google_drive"), + connectionId + ); + + oauth2Client.setCredentials({ + access_token: res.credentials.raw.access_token, + scope: res.credentials.raw.scope, + token_type: res.credentials.raw.token_type, + expiry_date: new Date(res.credentials.raw.expires_at).getTime(), + }); + } + + return oauth2Client; } export async function getAccessTokenFromNango( nangoIntegrationId: NangoIntegrationId, - nangoConnectionId: NangoConnectionId + nangoConnectionId: string ): Promise { const res = await nango.getConnection(nangoIntegrationId, nangoConnectionId); diff --git a/front/lib/resources/labs_transcripts_resource.ts b/front/lib/resources/labs_transcripts_resource.ts index f2b309884ccb..645098895ad1 100644 --- a/front/lib/resources/labs_transcripts_resource.ts +++ b/front/lib/resources/labs_transcripts_resource.ts @@ -1,4 +1,8 @@ -import type { LabsConnectorProvider, Result } from "@dust-tt/types"; +import type { + LabsConnectorProvider, + LabsTranscriptsProviderType, + Result, +} from "@dust-tt/types"; import { Err, Ok } from "@dust-tt/types"; import type { Attributes, @@ -52,6 +56,27 @@ export class LabsTranscriptsConfigurationResource extends BaseResource { + const configurations = await LabsTranscriptsConfigurationModel.findAll({ + where: { + provider, + }, + }); + + return configurations.map( + (configuration) => + new LabsTranscriptsConfigurationResource( + LabsTranscriptsConfigurationModel, + configuration.get() + ) + ); + } + static async findByUserAndWorkspace({ auth, userId, @@ -142,6 +167,11 @@ export class LabsTranscriptsConfigurationResource extends BaseResource = { + google_drive: NANGO_GOOGLE_DRIVE_CONNECTOR_ID, + gong: NANGO_GONG_CONNECTOR_ID, +}; + +const CONNECTORS_WITH_REFRESH_TOKENS = ["google_drive"]; + +async function appendRollbackCommand( + provider: LabsTranscriptsProviderType, + labsTranscriptConfigurationId: ModelId, + oldConnectionId: string +) { + const sql = `UPDATE labs_transcripts_configurations SET "connectionId" = '${oldConnectionId}' WHERE id = ${labsTranscriptConfigurationId};\n`; + await fs.appendFile(`${provider}_rollback_commands.sql`, sql); +} + +function getRedirectUri(provider: LabsTranscriptsProviderType): string { + return `${config.getDustAPIConfig().url}/oauth/${provider}/finalize`; +} + +async function migrateConfigurationId( + api: OAuthAPI, + provider: LabsTranscriptsProviderType, + configuration: LabsTranscriptsConfigurationResource, + logger: Logger, + execute: boolean +): Promise> { + logger.info( + `Migrating configuration id ${configuration.id}, current connectionId ${configuration.connectionId}.` + ); + + const user = await configuration.getUser(); + const workspace = await Workspace.findOne({ + where: { id: configuration.workspaceId }, + }); + + if (!user || !workspace) { + return new Err(new Error("User or workspace not found")); + } + + const integrationId = NANGO_CONNECTOR_IDS[provider]; + if (!integrationId) { + return new Err(new Error("Nango integration ID not found for provider")); + } + + // Retrieve connection from nango. + let connection: any | null = null; + try { + connection = await nango.getConnection( + integrationId, + configuration.connectionId, + true, // forceRefresh + true // returnRefreshToksn + ); + } catch (e) { + return new Err(new Error(`Nango error: ${e}`)); + } + + console.log( + ">>>>>>>>>>>>>>>>>>>>>>>>>>> BEG CONNECTION <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" + ); + console.log(connection); + console.log( + ">>>>>>>>>>>>>>>>>>>>>>>>>>> END CONNECTION <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" + ); + + if (!connection.credentials.access_token) { + return new Err(new Error("Could not retrieve `access_token` from Nango")); + } + + // We don't have authorization codes from Nango + const migratedCredentials: MigratedCredentialsType = { + redirect_uri: getRedirectUri(provider), + access_token: connection.credentials.access_token, + raw_json: connection.credentials.raw, + }; + + // If provider supports refresh tokens, migrate them. + if (CONNECTORS_WITH_REFRESH_TOKENS.includes(provider)) { + const thirtyMinutesFromNow = new Date(new Date().getTime() + 30 * 60000); + + if ( + !connection.credentials.expires_at || + new Date(connection.credentials.expires_at).getTime() < + thirtyMinutesFromNow.getTime() + ) { + return new Err( + new Error( + "Expires at is not set or is less than 30 minutes from now. Skipping migration." + ) + ); + } + + if (connection.credentials.expires_at) { + migratedCredentials.access_token_expiry = Date.parse( + connection.credentials.expires_at + ); + } + if (connection.credentials.refresh_token) { + migratedCredentials.refresh_token = connection.credentials.refresh_token; + } + } + + console.log( + ">>>>>>>>>>>>>>>>>>>>>>>>>>> BEG MIGRATED_CREDENTIALS <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" + ); + console.log(migratedCredentials); + console.log( + ">>>>>>>>>>>>>>>>>>>>>>>>>>> END MIGRATED_CREDENTIALS <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" + ); + + if (!execute) { + return new Ok(undefined); + } + + // Save the old connectionId for rollback. + const oldConnectionId = configuration.connectionId; + + // Create the connection with migratedCredentials. + const cRes = await api.createConnection({ + // TOOD(alban): remove the as once gong is an OAuthProvider. + provider: provider as OAuthProvider, + metadata: { + use_case: USE_CASE, + workspace_id: workspace.sId, + user_id: user.sId, + origin: "migrated", + }, + migratedCredentials, + }); + + if (cRes.isErr()) { + return cRes; + } + + const newConnectionId = cRes.value.connection.connection_id; + + // Append rollback command after successful update. + await appendRollbackCommand(provider, configuration.id, oldConnectionId); + + await configuration.updateConnectionId(newConnectionId); + + logger.info( + `Successfully migrated connection id for connector ${configuration.id}, new connectionId ${newConnectionId}.` + ); + + return new Ok(undefined); +} + +async function migrateAllConfigurations( + provider: LabsTranscriptsProviderType, + configurationId: ModelId | undefined, + logger: Logger, + execute: boolean +) { + const api = new OAuthAPI(config.getOAuthAPIConfig(), logger); + + const configurations = configurationId + ? removeNulls([ + await LabsTranscriptsConfigurationResource.fetchByModelId( + configurationId + ), + ]) + : await LabsTranscriptsConfigurationResource.listByProvider({ + provider, + }); + + logger.info( + `Found ${configurations.length} ${provider} configurations to migrate.` + ); + + for (const configuration of configurations) { + const localLogger = logger.child({ + configurationId: configuration.id, + workspaceId: configuration.workspaceId, + }); + + if (isDualUseOAuthConnectionId(configuration.connectionId)) { + localLogger.info("Skipping alreaydy migrated configuration"); + continue; + } + + const migrationRes = await migrateConfigurationId( + api, + provider, + configuration, + localLogger, + execute + ); + if (migrationRes.isErr()) { + localLogger.error( + { + error: migrationRes.error, + }, + "Failed to migrate configuration. Exiting." + ); + } + } + + logger.info(`Done migrating configurations.`); +} + +makeScript( + { + connectorId: { + alias: "c", + describe: "Connector ID", + type: "number", + }, + provider: { + alias: "p", + describe: "OAuth provider to migrate", + type: "string", + }, + }, + async ({ provider, connectorId, execute }, logger) => { + if (isOAuthProvider(provider)) { + await migrateAllConfigurations( + provider as LabsTranscriptsProviderType, + connectorId, + logger, + execute + ); + } else { + logger.error( + { + provider, + }, + "Invalid provider provided" + ); + } + } +); diff --git a/front/pages/w/[wId]/assistant/labs/transcripts/index.tsx b/front/pages/w/[wId]/assistant/labs/transcripts/index.tsx index ec3a2ee998a2..27845800a0cd 100644 --- a/front/pages/w/[wId]/assistant/labs/transcripts/index.tsx +++ b/front/pages/w/[wId]/assistant/labs/transcripts/index.tsx @@ -8,13 +8,13 @@ import { Spinner, XMarkIcon, } from "@dust-tt/sparkle"; +import type { SubscriptionType } from "@dust-tt/types"; +import type { LightAgentConfigurationType } from "@dust-tt/types"; import type { LabsTranscriptsProviderType, - UserType, WorkspaceType, } from "@dust-tt/types"; -import type { SubscriptionType } from "@dust-tt/types"; -import type { LightAgentConfigurationType } from "@dust-tt/types"; +import { setupOAuthConnection } from "@dust-tt/types"; import Nango from "@nangohq/frontend"; import type { InferGetServerSidePropsType } from "next"; import { useContext, useEffect, useState } from "react"; @@ -44,12 +44,11 @@ const defaultTranscriptConfigurationState = { export const getServerSideProps = withDefaultUserAuthRequirements<{ owner: WorkspaceType; - user: UserType; subscription: SubscriptionType; gaTrackingId: string; - nangoDriveConnectorId: string; nangoGongConnectorId: string; nangoPublicKey: string; + dustClientFacingUrl: string; }>(async (_context, auth) => { const owner = auth.workspace(); const subscription = auth.subscription(); @@ -69,25 +68,22 @@ export const getServerSideProps = withDefaultUserAuthRequirements<{ return { props: { owner, - user, subscription, gaTrackingId: apiConfig.getGaTrackingId(), - nangoDriveConnectorId: - config.getNangoConnectorIdForProvider("google_drive"), nangoGongConnectorId: config.getNangoConnectorIdForProvider("gong"), nangoPublicKey: config.getNangoPublicKey(), + dustClientFacingUrl: apiConfig.getClientFacingUrl(), }, }; }); export default function LabsTranscriptsIndex({ owner, - user, subscription, gaTrackingId, - nangoDriveConnectorId, nangoGongConnectorId, nangoPublicKey, + dustClientFacingUrl, }: InferGetServerSidePropsType) { const sendNotification = useContext(SendNotificationsContext); const [isDeleteProviderDialogOpened, setIsDeleteProviderDialogOpened] = @@ -249,70 +245,75 @@ export default function LabsTranscriptsIndex({ return updateIsActive(transcriptConfigurationId, isActive); }; - const saveOauthConnection = async ( + const saveOAuthConnection = async ( connectionId: string, provider: string ) => { - const response = await fetch(`/api/w/${owner.sId}/labs/transcripts`, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - connectionId, - provider, - }), - }); - - if (!response.ok) { + try { + const response = await fetch(`/api/w/${owner.sId}/labs/transcripts`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + connectionId, + provider, + }), + }); + if (!response.ok) { + sendNotification({ + type: "error", + title: "Failed to connect provider", + description: + "Could not connect to your transcripts provider. Please try again.", + }); + } else { + sendNotification({ + type: "success", + title: "Provider connected", + description: + "Your transcripts provider has been connected successfully.", + }); + + await mutateTranscriptsConfiguration(); + } + return response; + } catch (error) { sendNotification({ type: "error", title: "Failed to connect provider", description: - "Could not connect to your transcripts provider. Please try again.", + "Unexpected error trying to connect to your transcripts provider. Please try again. Error: " + + error, }); - } else { - sendNotification({ - type: "success", - title: "Provider connected", - description: - "Your transcripts provider has been connected successfully.", - }); - - await mutateTranscriptsConfiguration(); } - - return response; }; const handleConnectGoogleTranscriptsSource = async () => { - try { - if (transcriptsConfigurationState.provider !== "google_drive") { - return; - } - const nango = new Nango({ publicKey: nangoPublicKey }); - const nangoConnectionId = buildLabsConnectionId( - `labs-transcripts-workspace-${owner.id}-user-${user.id}`, - transcriptsConfigurationState.provider - ); - const { - connectionId: newConnectionId, - }: { providerConfigKey: string; connectionId: string } = await nango.auth( - nangoDriveConnectorId, - nangoConnectionId - ); + if (transcriptsConfigurationState.provider !== "google_drive") { + return; + } - await saveOauthConnection( - newConnectionId, - transcriptsConfigurationState.provider - ); - } catch (error) { + const cRes = await setupOAuthConnection({ + dustClientFacingUrl, + owner, + provider: "google_drive", + useCase: "labs_transcripts", + }); + + if (cRes.isErr()) { sendNotification({ type: "error", title: "Failed to connect Google Drive", - description: "Could not connect to Google Drive. Please try again.", + description: cRes.error.message, }); + return; } + + await saveOAuthConnection( + cRes.value.connection_id, + transcriptsConfigurationState.provider + ); }; const handleConnectGongTranscriptsSource = async () => { @@ -340,7 +341,7 @@ export default function LabsTranscriptsIndex({ return; } - await saveOauthConnection( + await saveOAuthConnection( defaultConfiguration.connectionId, transcriptsConfigurationState.provider ); @@ -357,7 +358,7 @@ export default function LabsTranscriptsIndex({ connectionId: newConnectionId, }: { providerConfigKey: string; connectionId: string } = await nango.auth(nangoGongConnectorId, nangoConnectionId); - await saveOauthConnection( + await saveOAuthConnection( newConnectionId, transcriptsConfigurationState.provider ); diff --git a/types/src/oauth/lib.ts b/types/src/oauth/lib.ts index 2488fd2d4e0b..e2f4fbc53de4 100644 --- a/types/src/oauth/lib.ts +++ b/types/src/oauth/lib.ts @@ -1,4 +1,4 @@ -export const OAUTH_USE_CASES = ["connection"] as const; +export const OAUTH_USE_CASES = ["connection", "labs_transcripts"] as const; export type OAuthUseCase = (typeof OAUTH_USE_CASES)[number];