From 4f7176a71f045402c6d12da1f17d9a18eff560df Mon Sep 17 00:00:00 2001 From: Flavien David Date: Fri, 19 Jul 2024 13:28:32 +0200 Subject: [PATCH] Add migration for Github connector to new OAuth (#6343) * Add migration for Github connector to new OAuth * :sparkles: * Fix rollback --- .../20240719_migrate_github_connection_id.ts | 132 ++++++++++++++++++ connectors/scripts/helpers.ts | 68 +++++++++ front/lib/api/oauth.ts | 18 +-- 3 files changed, 210 insertions(+), 8 deletions(-) create mode 100644 connectors/migrations/20240719_migrate_github_connection_id.ts create mode 100644 connectors/scripts/helpers.ts diff --git a/connectors/migrations/20240719_migrate_github_connection_id.ts b/connectors/migrations/20240719_migrate_github_connection_id.ts new file mode 100644 index 000000000000..c0d6c65735bd --- /dev/null +++ b/connectors/migrations/20240719_migrate_github_connection_id.ts @@ -0,0 +1,132 @@ +import type { ModelId, OAuthAPIError, Result } from "@dust-tt/types"; +import { OAuthAPI, Ok } from "@dust-tt/types"; +import { promises as fs } from "fs"; +import { makeScript } from "scripts/helpers"; + +import { apiConfig } from "@connectors/lib/api/config"; +import type { Logger } from "@connectors/logger/logger"; +import { ConnectorResource } from "@connectors/resources/connector_resource"; + +const PROVIDER = "github"; +const USE_CASE = "connection"; + +async function appendRollbackCommand( + connectorId: ModelId, + oldConnectionId: string +) { + const sql = `UPDATE connectors SET "connectionId" = '${oldConnectionId}' WHERE id = ${connectorId};\n`; + await fs.appendFile(`${PROVIDER}_rollback_commands.sql`, sql); +} + +function getRedirectUri(): string { + return `${apiConfig.getDustAPIConfig().url}/oauth/${PROVIDER}/finalize`; +} + +async function migrateGithubConnectionId( + api: OAuthAPI, + connector: ConnectorResource, + logger: Logger, + execute: boolean +): Promise> { + logger.info( + `Migrating connection id for connector ${connector.id}, current connectionId ${connector.connectionId}.` + ); + if (!execute) { + return new Ok(undefined); + } + + // Save the old connectionId for rollback. + const oldConnectionId = connector.connectionId; + + // First, we create the connection. + const cRes = await api.createConnection({ + provider: PROVIDER, + metadata: { + use_case: USE_CASE, + workspace_id: connector.workspaceId, + origin: "migrated", + }, + }); + + if (cRes.isErr()) { + return cRes; + } + + const newConnectionId = cRes.value.connection.connection_id; + + // Then we finalize the connection. + const fRes = await api.finalizeConnection({ + provider: PROVIDER, + connectionId: newConnectionId, + code: connector.connectionId, + redirectUri: getRedirectUri(), + }); + + if (fRes.isErr()) { + return fRes; + } + + // Append rollback command after successful update. + await appendRollbackCommand(connector.id, oldConnectionId); + + await connector.update({ + connectionId: newConnectionId, + }); + + logger.info( + `Successfully migrated connection id for connector ${connector.id}, new connectionId ${newConnectionId}.` + ); + + return new Ok(undefined); +} + +async function migrateAllGithubConnections( + connectorId: ModelId | undefined, + logger: Logger, + execute: boolean +) { + const api = new OAuthAPI(apiConfig.getOAuthAPIConfig(), logger); + + const connectors = connectorId + ? await ConnectorResource.fetchByIds(PROVIDER, [connectorId]) + : await ConnectorResource.listByType(PROVIDER, {}); + + logger.info(`Found ${connectors.length} GitHub connectors to migrate.`); + + for (const connector of connectors) { + const localLogger = logger.child({ + connectorId: connector.id, + workspaceId: connector.workspaceId, + }); + + const migrationRes = await migrateGithubConnectionId( + api, + connector, + localLogger, + execute + ); + if (migrationRes.isErr()) { + localLogger.error( + { + error: migrationRes.error, + }, + "Failed to migrate connector. Exiting." + ); + } + } + + logger.info(`Done migrating GitHub connectors.`); +} + +makeScript( + { + connectorId: { + alias: "c", + describe: "Connector ID", + type: "number", + }, + }, + async ({ connectorId, execute }, logger) => { + await migrateAllGithubConnections(connectorId, logger, execute); + } +); diff --git a/connectors/scripts/helpers.ts b/connectors/scripts/helpers.ts new file mode 100644 index 000000000000..200e72f7754c --- /dev/null +++ b/connectors/scripts/helpers.ts @@ -0,0 +1,68 @@ +import type { Options } from "yargs"; +import yargs from "yargs"; +import { hideBin } from "yargs/helpers"; + +import type { Logger } from "@connectors/logger/logger"; +import logger from "@connectors/logger/logger"; + +// Define a type for the argument specification object. +export type ArgumentSpecs = { + [key: string]: Options & { type?: "array" | "string" | "boolean" | "number" }; +}; + +// Define a type for the worker function. +type WorkerFunction = (args: T, logger: Logger) => Promise; + +// Define a utility type to infer the argument types from the argument specs. +type InferArgs = { + [P in keyof T]: T[P] extends { type: "number" } + ? number + : T[P] extends { type: "boolean" } + ? boolean + : T[P] extends { type: "string" } + ? string + : T[P] extends { type: "array" } + ? string[] + : never; +} & { execute?: boolean }; + +const defaultArgumentSpecs: ArgumentSpecs = { + execute: { + alias: "e", + describe: "Execute the script", + type: "boolean" as const, + default: false, + }, +}; + +export function makeScript( + argumentSpecs: T, + worker: WorkerFunction & { execute: boolean }> +): void { + const argv = yargs(hideBin(process.argv)); + + const combinedArgumentSpecs = { ...defaultArgumentSpecs, ...argumentSpecs }; + + // Configure yargs using the provided argument specifications. + Object.entries(combinedArgumentSpecs).forEach(([key, options]) => { + argv.option(key, options); + }); + + argv + .help("h") + .alias("h", "help") + .parseAsync() + .then(async (args) => { + const scriptLogger = logger.child({ + execute: args.execute, + }); + + await worker(args as InferArgs, scriptLogger); + + process.exit(0); + }) + .catch((error) => { + console.error("An error occurred:", error); + process.exit(1); + }); +} diff --git a/front/lib/api/oauth.ts b/front/lib/api/oauth.ts index 7158fbde9b5b..3550943d3731 100644 --- a/front/lib/api/oauth.ts +++ b/front/lib/api/oauth.ts @@ -234,15 +234,17 @@ export async function finalizeConnection( code, redirectUri: finalizeUriForProvider(provider), }); - logger.error( - { - provider, - connectionId, - step: "connection_finalization", - }, - "OAuth: Failed to finalize connection" - ); + if (cRes.isErr()) { + logger.error( + { + provider, + connectionId, + step: "connection_finalization", + }, + "OAuth: Failed to finalize connection" + ); + return new Err({ code: "connection_finalization_failed", message: `Failed to finalize ${provider} connection: ${cRes.error.message}`,