Skip to content

Commit

Permalink
[extension] Use dedicated audience for dust api. Add scopes verificat…
Browse files Browse the repository at this point in the history
…ion (#8719)
  • Loading branch information
tdraier authored Nov 19, 2024
1 parent 5e705e7 commit 217bc8b
Show file tree
Hide file tree
Showing 19 changed files with 190 additions and 52 deletions.
8 changes: 4 additions & 4 deletions extension/app/background.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ import type { PendingUpdate } from "@extension/lib/storage";
import { savePendingUpdate } from "@extension/lib/storage";

import {
AUTH0_AUDIENCE,
AUTH0_CLIENT_DOMAIN,
AUTH0_CLIENT_ID,
DUST_API_AUDIENCE,
} from "./src/lib/config";
import { extractPage } from "./src/lib/extraction";
import type {
Expand Down Expand Up @@ -266,10 +266,10 @@ const authenticate = async (
const options = {
client_id: AUTH0_CLIENT_ID,
response_type: "code",
// "offline_access" to receive refresh tokens to maintain user sessions without re-prompting for authentication.
scope: "openid offline_access",
scope:
"offline_access read:user_profile read:conversation create:conversation update:conversation read:agent read:file create:file delete:file",
redirect_uri: redirectUrl,
audience: AUTH0_AUDIENCE,
audience: DUST_API_AUDIENCE,
code_challenge_method: "S256",
code_challenge: codeChallenge,
};
Expand Down
1 change: 1 addition & 0 deletions extension/app/src/lib/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ export const AUTH0_CLIENT_DOMAIN = process.env.AUTH0_CLIENT_DOMAIN ?? "";
export const AUTH0_CLIENT_ID = process.env.AUTH0_CLIENT_ID ?? "";
export const AUTH0_AUDIENCE = `https://${AUTH0_CLIENT_DOMAIN}/api/v2/`;
export const AUTH0_PROFILE_ROUTE = `https://${AUTH0_CLIENT_DOMAIN}/userinfo`;
export const DUST_API_AUDIENCE = process.env.DUST_API_AUDIENCE ?? "";
121 changes: 91 additions & 30 deletions front/lib/api/auth0.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,68 @@
import type { Result } from "@dust-tt/types";
import { Err, Ok } from "@dust-tt/types";
import { ManagementClient } from "auth0";
import { isLeft } from "fp-ts/lib/Either";
import * as t from "io-ts";
import jwt from "jsonwebtoken";
import jwksClient from "jwks-rsa";
import type { NextApiRequest } from "next";

import config from "@app/lib/api/config";
import { UserResource } from "@app/lib/resources/user_resource";
import logger from "@app/logger/logger";

let auth0ManagemementClient: ManagementClient | null = null;

export const SUPPORTED_METHODS = [
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
] as const;
export type MethodType = (typeof SUPPORTED_METHODS)[number];

const isSupportedMethod = (method: string): method is MethodType =>
SUPPORTED_METHODS.includes(method as MethodType);

export type ScopeType =
| "read:user_profile"
| "read:conversation"
| "update:conversation"
| "create:conversation"
| "read:file"
| "update:file"
| "create:file"
| "delete:file"
| "read:agent";

export const Auth0JwtPayloadSchema = t.type({
azp: t.string,
exp: t.number,
scope: t.string,
sub: t.string,
});

export type Auth0JwtPayload = t.TypeOf<typeof Auth0JwtPayloadSchema> &
jwt.JwtPayload;

export function getRequiredScope(
req: NextApiRequest,
requiredScopes?: Partial<Record<MethodType, ScopeType>>
) {
const method = req.method;

if (
method &&
isSupportedMethod(method) &&
requiredScopes &&
requiredScopes[method]
) {
return requiredScopes[method];
}
return undefined;
}

export function getAuth0ManagemementClient(): ManagementClient {
if (!auth0ManagemementClient) {
auth0ManagemementClient = new ManagementClient({
Expand Down Expand Up @@ -50,13 +105,21 @@ async function getSigningKey(jwksUri: string, kid: string): Promise<string> {
* Verify an Auth0 token.
* Not meant to be exported, use `getUserFromAuth0Token` instead.
*/
async function verifyAuth0Token(accessToken: string): Promise<jwt.JwtPayload> {
export async function verifyAuth0Token(
accessToken: string,
requiredScope?: ScopeType
): Promise<Result<Auth0JwtPayload, Error>> {
const auth0Domain = config.getAuth0TenantUrl();
const audience = `https://${auth0Domain}/api/v2/`;
const audience = config.getDustApiAudience();
const verify = `https://${auth0Domain}/.well-known/jwks.json`;
const issuer = `https://${auth0Domain}/`;

return new Promise((resolve, reject) => {
// TODO(thomas): Remove this when all clients are updated.
const legacyAudience = `https://${auth0Domain}/api/v2/`;
const decoded = jwt.decode(accessToken, { json: true });
const useLegacy = decoded && decoded.aud === legacyAudience;

return new Promise((resolve) => {
jwt.verify(
accessToken,
async (header, callback) => {
Expand All @@ -72,48 +135,46 @@ async function verifyAuth0Token(accessToken: string): Promise<jwt.JwtPayload> {
},
{
algorithms: ["RS256"],
audience: audience,
audience: useLegacy ? legacyAudience : audience,
issuer: issuer,
},
(err, decoded) => {
if (err) {
reject(err);
return;
return resolve(new Err(err));
}
if (!decoded || typeof decoded !== "object") {
reject(new Error("No token payload"));
return;
return resolve(new Err(Error("No token payload")));
}
resolve(decoded);

const payloadValidation = Auth0JwtPayloadSchema.decode(decoded);
if (isLeft(payloadValidation)) {
logger.error("Invalid token payload.");
return resolve(new Err(Error("Invalid token payload.")));
}

if (requiredScope && !useLegacy) {
const availableScopes = decoded.scope.split(" ");
if (!availableScopes.includes(requiredScope)) {
logger.error(
{ requiredScopes: requiredScope },
"Insufficient scopes."
);
return resolve(new Err(Error("Insufficient scopes.")));
}
}

return resolve(new Ok(payloadValidation.right));
}
);
});
}

/**
* Get a user resource from an Auth0 token.
* We return the user from its Auth0 sub, only if the token is not expired.
*/
export async function getUserFromAuth0Token(
accessToken: string
accessToken: Auth0JwtPayload
): Promise<UserResource | null> {
let decoded: jwt.JwtPayload;
try {
decoded = await verifyAuth0Token(accessToken);
} catch (error) {
logger.error({ error }, "Error verifying Auth0 token");
return null;
}

const now = Math.floor(Date.now() / 1000);

if (
typeof decoded.sub !== "string" ||
typeof decoded.exp !== "number" ||
decoded.exp <= now
) {
logger.error("Invalid or expired token payload.");
return null;
}

return UserResource.fetchByAuth0Sub(decoded.sub);
return UserResource.fetchByAuth0Sub(accessToken.sub);
}
3 changes: 3 additions & 0 deletions front/lib/api/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ const config = {
getAuth0TenantUrl: (): string => {
return EnvironmentConfig.getEnvVariable("AUTH0_TENANT_DOMAIN_URL");
},
getDustApiAudience: (): string => {
return EnvironmentConfig.getEnvVariable("DUST_API_AUDIENCE");
},
getAuth0M2MClientId: (): string => {
return EnvironmentConfig.getEnvVariable("AUTH0_M2M_CLIENT_ID");
},
Expand Down
50 changes: 45 additions & 5 deletions front/lib/api/wrappers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ import type {
} from "@dust-tt/types";
import type { NextApiRequest, NextApiResponse } from "next";

import { getUserFromAuth0Token } from "@app/lib/api/auth0";
import type { MethodType, ScopeType } from "@app/lib/api/auth0";
import {
getRequiredScope,
getUserFromAuth0Token,
verifyAuth0Token,
} from "@app/lib/api/auth0";
import { getUserWithWorkspaces } from "@app/lib/api/user";
import {
Authenticator,
Expand Down Expand Up @@ -47,7 +52,7 @@ export function withSessionAuthentication<T>(
api_error: {
type: "not_authenticated",
message:
"The user does not have an active session or is not authenticated",
"The user does not have an active session or is not authenticated.",
},
});
}
Expand Down Expand Up @@ -164,6 +169,7 @@ export function withPublicAPIAuthentication<T, U extends boolean>(
opts: {
isStreaming?: boolean;
allowUserOutsideCurrentWorkspace?: U;
requiredScopes?: Partial<Record<MethodType, ScopeType>>;
} = {}
) {
const { allowUserOutsideCurrentWorkspace, isStreaming } = opts;
Expand Down Expand Up @@ -201,8 +207,23 @@ export function withPublicAPIAuthentication<T, U extends boolean>(
// Authentification with Auth0 token.
// Straightforward since the token is attached to the user.
if (authMethod === "access_token") {
const auth = await Authenticator.fromAuth0Token({
const decoded = await verifyAuth0Token(
token,
getRequiredScope(req, opts.requiredScopes)
);
if (decoded.isErr()) {
return apiError(req, res, {
status_code: 401,
api_error: {
type: "not_authenticated",
message:
"The request does not have valid authentication credentials.",
},
});
}

const auth = await Authenticator.fromAuth0Token({
token: decoded.value,
wId,
});
if (auth.user() === null) {
Expand Down Expand Up @@ -322,7 +343,10 @@ export function withAuth0TokenAuthentication<T>(
req: NextApiRequest,
res: NextApiResponse<WithAPIErrorResponse<T>>,
user: UserTypeWithWorkspaces
) => Promise<void> | void
) => Promise<void> | void,
opts: {
requiredScopes?: Partial<Record<MethodType, ScopeType>>;
} = {}
) {
return withLogging(
async (
Expand Down Expand Up @@ -354,7 +378,23 @@ export function withAuth0TokenAuthentication<T>(
});
}

const user = await getUserFromAuth0Token(bearerToken);
const decoded = await verifyAuth0Token(
bearerToken,
getRequiredScope(req, opts.requiredScopes)
);
if (decoded.isErr()) {
return apiError(req, res, {
status_code: 401,
api_error: {
type: "not_authenticated",
message:
"The request does not have valid authentication credentials.",
},
});
}

const user = await getUserFromAuth0Token(decoded.value);
// TODO(thomas): user not found : means the user is not registered, display a message to the user and redirects to site
if (!user) {
return apiError(req, res, {
status_code: 401,
Expand Down
3 changes: 2 additions & 1 deletion front/lib/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import type {
NextApiResponse,
} from "next";

import type { Auth0JwtPayload } from "@app/lib/api/auth0";
import { getUserFromAuth0Token } from "@app/lib/api/auth0";
import config from "@app/lib/api/config";
import type { SessionWithUser } from "@app/lib/iam/provider";
Expand Down Expand Up @@ -299,7 +300,7 @@ export class Authenticator {
token,
wId,
}: {
token: string;
token: Auth0JwtPayload;
wId: string;
}): Promise<Authenticator> {
const user = await getUserFromAuth0Token(token);
Expand Down
4 changes: 3 additions & 1 deletion front/pages/api/v1/me.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,6 @@ async function handler(
}
}

export default withAuth0TokenAuthentication(handler);
export default withAuth0TokenAuthentication(handler, {
requiredScopes: { GET: "read:user_profile" },
});
4 changes: 3 additions & 1 deletion front/pages/api/v1/w/[wId]/assistant/agent_configurations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,6 @@ async function handler(
}
}

export default withPublicAPIAuthentication(handler);
export default withPublicAPIAuthentication(handler, {
requiredScopes: { GET: "read:agent" },
});
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,5 @@ async function handler(

export default withPublicAPIAuthentication(handler, {
isStreaming: true,
requiredScopes: { POST: "update:conversation" },
});
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,6 @@ async function handler(
}
}

export default withPublicAPIAuthentication(handler);
export default withPublicAPIAuthentication(handler, {
requiredScopes: { POST: "update:conversation" },
});
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,7 @@ async function handler(
}
}

export default withPublicAPIAuthentication(handler, { isStreaming: true });
export default withPublicAPIAuthentication(handler, {
isStreaming: true,
requiredScopes: { GET: "read:conversation" },
});
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,6 @@ async function handler(
}
}

export default withPublicAPIAuthentication(handler);
export default withPublicAPIAuthentication(handler, {
requiredScopes: { GET: "read:conversation" },
});
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,6 @@ async function handler(
}
}

export default withPublicAPIAuthentication(handler);
export default withPublicAPIAuthentication(handler, {
requiredScopes: { POST: "update:conversation" },
});
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,7 @@ async function handler(
}
}

export default withPublicAPIAuthentication(handler, { isStreaming: true });
export default withPublicAPIAuthentication(handler, {
isStreaming: true,
requiredScopes: { GET: "read:conversation" },
});
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,7 @@ async function handler(
}
}

export default withPublicAPIAuthentication(handler, { isStreaming: true });
export default withPublicAPIAuthentication(handler, {
isStreaming: true,
requiredScopes: { POST: "update:conversation" },
});
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,6 @@ async function handler(
}
}

export default withPublicAPIAuthentication(handler);
export default withPublicAPIAuthentication(handler, {
requiredScopes: { POST: "update:conversation" },
});
Loading

0 comments on commit 217bc8b

Please sign in to comment.