diff --git a/packages/api/src/controllers/experiment.ts b/packages/api/src/controllers/experiment.ts index d995f5a307..0b046b050b 100644 --- a/packages/api/src/controllers/experiment.ts +++ b/packages/api/src/controllers/experiment.ts @@ -2,11 +2,7 @@ import { Router } from "express"; import _ from "lodash"; import { v4 as uuid, validate as validateUuid } from "uuid"; import { db } from "../store"; -import { - NotFoundError, - ForbiddenError, - BadRequestError, -} from "../store/errors"; +import { NotFoundError, BadRequestError } from "../store/errors"; import { makeNextHREF, parseFilters, @@ -19,7 +15,7 @@ import { authorizer, validatePost } from "../middleware"; import { WithID } from "../store/types"; import experimentApis from "./experiment/index"; -import { isExperimentSubject } from "../store/experiment-table"; +import { ensureExperimentSubject } from "../store/experiment-table"; async function toUserId(emailOrId: string) { let user: User; @@ -42,10 +38,7 @@ const app = Router(); const experimentSubjectsOnly = (experiment: string) => async (req, res, next) => { - const isSubject = await isExperimentSubject(experiment, req.user?.id); - if (!isSubject) { - throw new ForbiddenError("user is not an experiment subject"); - } + await ensureExperimentSubject(experiment, req.user?.id); return next(); }; @@ -75,10 +68,7 @@ app.get("/check/:experiment", authorizer({}), async (req, res) => { } const { experiment: experimentQuery } = req.params; - const isSubject = await isExperimentSubject(experimentQuery, user.id); - if (!isSubject) { - throw new ForbiddenError("user is not an experiment subject"); - } + await ensureExperimentSubject(experimentQuery, user.id); res.status(204).end(); }); diff --git a/packages/api/src/store/experiment-table.ts b/packages/api/src/store/experiment-table.ts index 2b01932148..9c8b21646b 100644 --- a/packages/api/src/store/experiment-table.ts +++ b/packages/api/src/store/experiment-table.ts @@ -2,7 +2,7 @@ import sql from "sql-template-strings"; import { Experiment } from "../schema/types"; import db from "./db"; -import { NotFoundError } from "./errors"; +import { ForbiddenError, NotFoundError } from "./errors"; import Table from "./table"; import { WithID } from "./types"; @@ -11,6 +11,17 @@ export async function isExperimentSubject(experiment: string, userId: string) { return audienceUserIds.includes(userId); } +export async function ensureExperimentSubject( + experiment: string, + userId: string +) { + if (!(await isExperimentSubject(experiment, userId))) { + throw new ForbiddenError( + `user is not a subject of experiment: ${experiment}` + ); + } +} + export default class ExperimentTable extends Table> { async listUserExperiments( userId: string,