Skip to content

Commit

Permalink
getOAuthConnectionAccessTokenWithThrow (#6387)
Browse files Browse the repository at this point in the history
* getOAuthConnectionAccessTokenWithThrow

* fix imports

* lint
  • Loading branch information
spolu authored Jul 22, 2024
1 parent 0348780 commit e96ddc3
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 73 deletions.
14 changes: 3 additions & 11 deletions connectors/src/connectors/github/lib/github_api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
} from "@connectors/connectors/github/lib/github_graphql";
import { apiConfig } from "@connectors/lib/api/config";
import { ExternalOauthTokenError } from "@connectors/lib/error";
import { getOAuthConnectionAccessTokenWithThrow } from "@connectors/lib/oauth";
import logger from "@connectors/logger/logger";

const API_PAGE_SIZE = 100;
Expand Down Expand Up @@ -543,22 +544,13 @@ export async function getDiscussion(
}

export async function getOctokit(connectionId: string): Promise<Octokit> {
const tokRes = await getOAuthConnectionAccessToken({
config: apiConfig.getOAuthAPIConfig(),
const token = await getOAuthConnectionAccessTokenWithThrow({
logger,
provider: "github",
connectionId,
});

if (tokRes.isErr()) {
logger.error(
{ connectionId, error: tokRes.error },
"Error retrieving Github access token"
);
throw new Error("Error retrieving Github access token");
}

return new Octokit({ auth: tokRes.value.access_token });
return new Octokit({ auth: token.access_token });
}

// Repository processing
Expand Down
26 changes: 10 additions & 16 deletions connectors/src/connectors/google_drive/temporal/utils.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import { cacheWithRedis, getOAuthConnectionAccessToken } from "@dust-tt/types";
import { cacheWithRedis } from "@dust-tt/types";
import type { drive_v3 } from "googleapis";
import { google } from "googleapis";
import { OAuth2Client } from "googleapis-common";

import { googleDriveConfig } from "@connectors/connectors/google_drive/lib/config";
import { apiConfig } from "@connectors/lib/api/config";
import type { NangoConnectionResponse } from "@connectors/lib/nango_helpers";
import { getConnectionFromNango } from "@connectors/lib/nango_helpers";
import { isDualUseOAuthConnectionId } from "@connectors/lib/oauth";
import {
getOAuthConnectionAccessTokenWithThrow,
isDualUseOAuthConnectionId,
} from "@connectors/lib/oauth";
import logger from "@connectors/logger/logger";
import type { GoogleDriveObjectType } from "@connectors/types/google_drive";

Expand Down Expand Up @@ -133,26 +135,18 @@ export async function getAuthObject(
): Promise<OAuth2Client> {
const oauth2Client = new google.auth.OAuth2();
if (isDualUseOAuthConnectionId(connectionId)) {
const tokRes = await getOAuthConnectionAccessToken({
config: apiConfig.getOAuthAPIConfig(),
const token = await getOAuthConnectionAccessTokenWithThrow({
logger,
provider: "google_drive",
connectionId,
});
if (tokRes.isErr()) {
logger.error(
{ connectionId, error: tokRes.error },
"Error retrieving Google access token"
);
throw new Error("Error retrieving Google access token");
}

oauth2Client.setCredentials({
access_token: tokRes.value.access_token,
scope: (tokRes.value.scrubbed_raw_json as { scope: string }).scope,
token_type: (tokRes.value.scrubbed_raw_json as { token_type: string })
access_token: token.access_token,
scope: (token.scrubbed_raw_json as { scope: string }).scope,
token_type: (token.scrubbed_raw_json as { token_type: string })
.token_type,
expiry_date: tokRes.value.access_token_expiry,
expiry_date: token.access_token_expiry,
});
} else {
const res: NangoConnectionResponse = await getConnectionFromNango({
Expand Down
21 changes: 6 additions & 15 deletions connectors/src/connectors/intercom/lib/intercom_access_token.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { getOAuthConnectionAccessToken } from "@dust-tt/types";

import { apiConfig } from "@connectors/lib/api/config";
import { getAccessTokenFromNango } from "@connectors/lib/nango_helpers";
import { isDualUseOAuthConnectionId } from "@connectors/lib/oauth";
import {
getOAuthConnectionAccessTokenWithThrow,
isDualUseOAuthConnectionId,
} from "@connectors/lib/oauth";
import logger from "@connectors/logger/logger";

const { NANGO_INTERCOM_CONNECTOR_ID } = process.env;
Expand All @@ -11,21 +11,12 @@ export async function getIntercomAccessToken(
connectionId: string
): Promise<string> {
if (isDualUseOAuthConnectionId(connectionId)) {
const tokRes = await getOAuthConnectionAccessToken({
config: apiConfig.getOAuthAPIConfig(),
const token = await getOAuthConnectionAccessTokenWithThrow({
logger,
provider: "intercom",
connectionId,
});
if (tokRes.isErr()) {
logger.error(
{ connectionId, error: tokRes.error },
"Error retrieving Intercom access token"
);
throw new Error("Error retrieving Intercom access token");
}

return tokRes.value.access_token;
return token.access_token;
} else {
// TODO(@fontanierh) INTERCOM_MIGRATION remove once migrated
if (!NANGO_INTERCOM_CONNECTOR_ID) {
Expand Down
25 changes: 7 additions & 18 deletions connectors/src/connectors/notion/temporal/activities.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@ import type {
NotionGarbageCollectionMode,
} from "@dust-tt/types";
import type { PageObjectProperties, ParsedNotionBlock } from "@dust-tt/types";
import {
assertNever,
getNotionDatabaseTableId,
getOAuthConnectionAccessToken,
slugify,
} from "@dust-tt/types";
import { assertNever, getNotionDatabaseTableId, slugify } from "@dust-tt/types";
import { isFullBlock, isFullPage, isNotionClientError } from "@notionhq/client";
import type { PageObjectResponse } from "@notionhq/client/build/src/api-endpoints";
import { Context } from "@temporalio/activity";
Expand Down Expand Up @@ -46,7 +41,6 @@ import {
updateAllParentsFields,
} from "@connectors/connectors/notion/lib/parents";
import { getTagsForPage } from "@connectors/connectors/notion/lib/tags";
import { apiConfig } from "@connectors/lib/api/config";
import { dataSourceConfigFromConnector } from "@connectors/lib/api/data_source_config";
import { concurrentExecutor } from "@connectors/lib/async_utils";
import {
Expand All @@ -71,7 +65,10 @@ import {
NotionPage,
} from "@connectors/lib/models/notion";
import { getAccessTokenFromNango } from "@connectors/lib/nango_helpers";
import { isDualUseOAuthConnectionId } from "@connectors/lib/oauth";
import {
getOAuthConnectionAccessTokenWithThrow,
isDualUseOAuthConnectionId,
} from "@connectors/lib/oauth";
import { redisClient } from "@connectors/lib/redis";
import { syncStarted, syncSucceeded } from "@connectors/lib/sync_status";
import { heartbeat } from "@connectors/lib/temporal";
Expand Down Expand Up @@ -559,20 +556,12 @@ export async function getNotionAccessToken(
connectionId: string
): Promise<string> {
if (isDualUseOAuthConnectionId(connectionId)) {
const tokRes = await getOAuthConnectionAccessToken({
config: apiConfig.getOAuthAPIConfig(),
const token = await getOAuthConnectionAccessTokenWithThrow({
logger,
provider: "notion",
connectionId,
});
if (tokRes.isErr()) {
logger.error(
{ connectionId, error: tokRes.error },
"Error retrieving Notion access token"
);
throw new Error("Error retrieving Notion access token");
}
return tokRes.value.access_token;
return token.access_token;
} else {
return getAccessTokenFromNango({
connectionId: connectionId,
Expand Down
15 changes: 3 additions & 12 deletions connectors/src/connectors/slack/lib/slack_client.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import type { ModelId } from "@dust-tt/types";
import { getOAuthConnectionAccessToken } from "@dust-tt/types";
import type {
CodedError,
WebAPIHTTPError,
WebAPIPlatformError,
} from "@slack/web-api";
import { ErrorCode, WebClient } from "@slack/web-api";

import { apiConfig } from "@connectors/lib/api/config";
import {
ExternalOauthTokenError,
ProviderWorkflowError,
} from "@connectors/lib/error";
import { getOAuthConnectionAccessTokenWithThrow } from "@connectors/lib/oauth";
import logger from "@connectors/logger/logger";
import { ConnectorResource } from "@connectors/resources/connector_resource";

Expand Down Expand Up @@ -183,19 +182,11 @@ export async function getSlackConversationInfo(
export async function getSlackAccessToken(
connectionId: string
): Promise<string> {
const tokRes = await getOAuthConnectionAccessToken({
config: apiConfig.getOAuthAPIConfig(),
const token = await getOAuthConnectionAccessTokenWithThrow({
logger,
provider: "slack",
connectionId,
});
if (tokRes.isErr()) {
logger.error(
{ connectionId, error: tokRes.error },
"Error retrieving Slack access token"
);
throw new Error("Error retrieving Slack access token");
}

return tokRes.value.access_token;
return token.access_token;
}
47 changes: 47 additions & 0 deletions connectors/src/lib/oauth.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,54 @@
import type { OAuthConnectionType, OAuthProvider } from "@dust-tt/types";
import { getOAuthConnectionAccessToken } from "@dust-tt/types";
import type { LoggerInterface } from "@dust-tt/types/dist/shared/logger";

import { apiConfig } from "@connectors/lib/api/config";
import { ExternalOauthTokenError } from "@connectors/lib/error";

// This function is used to discreminate between a new OAuth connection and an old Nango/Github
// connection. It is used to support dual-use while migrating and should be unused by a connector
// once fully migrated
export function isDualUseOAuthConnectionId(connectionId: string): boolean {
// TODO(spolu): make sure this function is removed once fully migrated.
return connectionId.startsWith("con_");
}

// Most connectors are built on the assumption that errors are thrown with special handling of
// selected errors such as ExternalOauthTokenError. This function is used to retrieve an OAuth
// connection access token and throw an ExternalOauthTokenError if the token is revoked.
export async function getOAuthConnectionAccessTokenWithThrow({
logger,
provider,
connectionId,
}: {
logger: LoggerInterface;
provider: OAuthProvider;
connectionId: string;
}): Promise<{
connection: OAuthConnectionType;
access_token: string;
access_token_expiry: number;
scrubbed_raw_json: unknown;
}> {
const tokRes = await getOAuthConnectionAccessToken({
config: apiConfig.getOAuthAPIConfig(),
logger,
provider,
connectionId,
});

if (tokRes.isErr()) {
logger.error(
{ connectionId, error: tokRes.error, provider },
"Error retrieving access token"
);

if (tokRes.error.code === "token_revoked_error") {
throw new ExternalOauthTokenError();
} else {
throw new Error(`Error retrieving access token from ${provider}`);
}
}

return tokRes.value;
}
2 changes: 1 addition & 1 deletion core/src/oauth/providers/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub async fn execute_request(
.unwrap_or_else(|_| String::from("Unable to read response body"));

return Err(ProviderHttpRequestError::RequestFailed {
provider: provider,
provider,
status: status.as_u16(),
message: body,
});
Expand Down

0 comments on commit e96ddc3

Please sign in to comment.