diff --git a/api/src/healthchecks/postgresCheck.ts b/api/src/healthchecks/postgresCheck.ts index f67f38cf..797a4eec 100644 --- a/api/src/healthchecks/postgresCheck.ts +++ b/api/src/healthchecks/postgresCheck.ts @@ -33,7 +33,7 @@ export async function postgresAvailableConnectionsCheck(dal: DataAccessLayer) { const check: any = { name: "@gram/api-postgres-available-connections", actionable: true, - healthy: dal.pool.waitingCount === 0, + healthy: dal.pool._pool.waitingCount === 0, dependentOn: "postgres", type: physical.type.EXTERNAL_DEPENDENCY, severity: physical.severity.WARNING, diff --git a/api/src/resources/gram/v1/models/create.spec.ts b/api/src/resources/gram/v1/models/create.spec.ts index 234d8fda..04cecccb 100644 --- a/api/src/resources/gram/v1/models/create.spec.ts +++ b/api/src/resources/gram/v1/models/create.spec.ts @@ -1,17 +1,9 @@ -import { jest } from "@jest/globals"; -import pg from "pg"; -import request from "supertest"; -import * as jwt from "@gram/core/dist/auth/jwt.js"; import { DataAccessLayer } from "@gram/core/dist/data/dal.js"; import Model, { Component } from "@gram/core/dist/data/models/Model.js"; import { _deleteAllTheThings } from "@gram/core/dist/data/utils.js"; +import request from "supertest"; import { createTestApp } from "../../../../test-util/app.js"; import { sampleOwnedSystem } from "../../../../test-util/sampleOwnedSystem.js"; -import { - sampleAdmin, - sampleOtherUser, - sampleUser, -} from "../../../../test-util/sampleUser.js"; import { sampleAdminToken, sampleOtherUserToken, @@ -23,7 +15,6 @@ describe("models.create", () => { const componentId2 = "fe93572e-9d0c-4afe-b042-e02c1c459999"; const dataFlowId = "fe93572e-9d0c-4afe-b042-e02c1cstonks"; let app: any; - let pool: pg.Pool; let dal: DataAccessLayer; let token = ""; let adminToken = ""; @@ -31,11 +22,11 @@ describe("models.create", () => { beforeAll(async () => { adminToken = await sampleAdminToken(); token = await sampleUserToken(); - ({ app, dal, pool } = await createTestApp()); + ({ app, dal } = await createTestApp()); }); beforeEach(async () => { - await _deleteAllTheThings(pool); + await _deleteAllTheThings(dal.pool); }); it("should return 401 on un-authenticated request", async () => { diff --git a/api/src/test-util/app.ts b/api/src/test-util/app.ts index 4ee22989..1285682a 100644 --- a/api/src/test-util/app.ts +++ b/api/src/test-util/app.ts @@ -4,5 +4,5 @@ import { bootstrap } from "@gram/core/dist/bootstrap.js"; export async function createTestApp() { const dal = await bootstrap(); const app = await createApp(dal); - return { pool: dal.pool, app, dal }; + return { pool: dal.pool._pool, app, dal }; } diff --git a/api/src/ws/index.ts b/api/src/ws/index.ts index 9cc61c40..cb76ebe6 100644 --- a/api/src/ws/index.ts +++ b/api/src/ws/index.ts @@ -101,6 +101,7 @@ export function attachWebsocketServer(server: Server, dal: DataAccessLayer) { if (!server) return; server.tellClientsToRefetch("threats", { modelId, componentId }); }); + dal.threatService.on("deleted-for", ({ modelId, componentId }) => { const server = wssRegistry.get(modelId); log.debug(`threat was deleted via api ${modelId} ${componentId}`); diff --git a/core/package.json b/core/package.json index 41eb5517..9eba19c1 100644 --- a/core/package.json +++ b/core/package.json @@ -41,7 +41,7 @@ "handlebars": "^4.7.7", "jsonwebtoken": "^9.0.0", "log4js": "^6.6.1", - "pg": "^8.11.1", + "pg": "^8.11.3", "postgres-migrations": "^5.3.0", "prom-client": "^14.0.1", "tslib": "^2.3.1" diff --git a/core/src/data/banners/BannerDataService.ts b/core/src/data/banners/BannerDataService.ts index dbdaaff6..c657f2ff 100644 --- a/core/src/data/banners/BannerDataService.ts +++ b/core/src/data/banners/BannerDataService.ts @@ -1,6 +1,6 @@ -import pg from "pg"; import log4js from "log4js"; import { DataAccessLayer } from "../dal.js"; +import { GramConnectionPool } from "../postgres.js"; interface Banner { id: number; @@ -10,7 +10,11 @@ interface Banner { } export class BannerDataService { - constructor(private pool: pg.Pool, private dal: DataAccessLayer) {} + constructor(private dal: DataAccessLayer) { + this.pool = dal.pool; + } + + private pool: GramConnectionPool; log = log4js.getLogger("BannerDataService"); diff --git a/core/src/data/controls/ControlDataService.spec.ts b/core/src/data/controls/ControlDataService.spec.ts index 4f6d0588..37d7de68 100644 --- a/core/src/data/controls/ControlDataService.spec.ts +++ b/core/src/data/controls/ControlDataService.spec.ts @@ -12,25 +12,23 @@ import { ControlDataService } from "./ControlDataService.js"; describe("ControlDataService implementation", () => { let data: ControlDataService; let dal: DataAccessLayer; - let pool: pg.Pool; let model: Model; beforeAll(async () => { - pool = await createPostgresPool(); + const pool = await createPostgresPool(); dal = new DataAccessLayer(pool); - data = new ControlDataService(pool, dal); - await _deleteAllTheThings(pool); + data = new ControlDataService(dal); }); beforeEach(async () => { - await _deleteAllTheThings(pool); + await _deleteAllTheThings(dal); model = new Model("some-system-id", "some-version", "root"); model.data = { components: [], dataFlows: [] }; model.id = await dal.modelService.create(model); }); afterAll(async () => { - await pool.end(); + await dal.pool.end(); }); describe("getById control", () => { diff --git a/core/src/data/controls/ControlDataService.ts b/core/src/data/controls/ControlDataService.ts index 3ab2fb46..0e276550 100644 --- a/core/src/data/controls/ControlDataService.ts +++ b/core/src/data/controls/ControlDataService.ts @@ -5,6 +5,7 @@ import { SuggestionID } from "../../suggestions/models.js"; import { DataAccessLayer } from "../dal.js"; import { SuggestionStatus } from "../suggestions/Suggestion.js"; import Control from "./Control.js"; +import { GramConnectionPool } from "../postgres.js"; export function convertToControl(row: any) { const control = new Control( @@ -23,15 +24,13 @@ export function convertToControl(row: any) { } export class ControlDataService extends EventEmitter { - constructor(pool: pg.Pool, dal: DataAccessLayer) { + constructor(private dal: DataAccessLayer) { super(); - this.pool = pool; - this.dal = dal; + this.pool = dal.pool; this.log = log4js.getLogger("ControlDataService"); } - private pool: pg.Pool; - private dal: DataAccessLayer; + private pool: GramConnectionPool; log: any; /** * Create a control object of specified id @@ -173,12 +172,9 @@ export class ControlDataService extends EventEmitter { WHERE c.model_id = $1::uuid AND m.control_id = c.id AND c.id ${filter} `; - const client = await this.pool.connect(); - let result = false; - try { - await client.query("BEGIN"); + const result = await this.pool.runTransaction(async (client) => { const res = await client.query(query, [modelId, ...ids]); - result = res.rowCount > 0; + const result = res.rowCount > 0; if (result) { const suggestionIds = res.rows @@ -200,13 +196,9 @@ export class ControlDataService extends EventEmitter { componentId: res.rows[0].component_id, }); } - await client.query("COMMIT"); - } catch (e) { - this.log.error("Failed to delete control", e); - await client.query("ROLLBACK"); - } finally { - client.release(); - } + + return result; + }); return result; } diff --git a/core/src/data/dal.ts b/core/src/data/dal.ts index d6f155f1..42e7190d 100644 --- a/core/src/data/dal.ts +++ b/core/src/data/dal.ts @@ -19,14 +19,18 @@ import { authzProvider } from "../auth/authorization.js"; import { systemProvider } from "./systems/systems.js"; import { SystemProvider } from "./systems/SystemProvider.js"; import { ReviewerHandler } from "./reviews/ReviewerHandler.js"; -import { createPostgresPool, getDatabaseName } from "./postgres.js"; +import { + GramConnectionPool, + createPostgresPool, + getDatabaseName, +} from "./postgres.js"; /** * Class that carries access to all DataServices, useful for passing dependencies. */ export class DataAccessLayer { // Database Connection Pool for direct access to postgres - pool: pg.Pool; + pool: GramConnectionPool; // DataServices - specific logic to handle database interactions modelService: ModelDataService; @@ -62,7 +66,7 @@ export class DataAccessLayer { } constructor(pool: pg.Pool) { - this.pool = pool; + this.pool = new GramConnectionPool(pool); this.sysPropHandler = new SystemPropertyHandler(); this.ccHandler = new ComponentClassHandler(); this.templateHandler = new TemplateHandler(); @@ -71,15 +75,15 @@ export class DataAccessLayer { this.reviewerHandler = new ReviewerHandler(); // Initialize Data Services - this.modelService = new ModelDataService(pool, this); - this.controlService = new ControlDataService(pool, this); - this.threatService = new ThreatDataService(pool, this); - this.mitigationService = new MitigationDataService(pool); - this.notificationService = new NotificationDataService(pool, this); - this.reviewService = new ReviewDataService(pool, this); - this.suggestionService = new SuggestionDataService(pool, this); + this.modelService = new ModelDataService(this); + this.controlService = new ControlDataService(this); + this.threatService = new ThreatDataService(this); + this.mitigationService = new MitigationDataService(this); + this.notificationService = new NotificationDataService(this); + this.reviewService = new ReviewDataService(this); + this.suggestionService = new SuggestionDataService(this); this.suggestionEngine = new SuggestionEngine(this); - this.reportService = new ReportDataService(pool, this); - this.bannerService = new BannerDataService(pool, this); + this.reportService = new ReportDataService(this); + this.bannerService = new BannerDataService(this); } } diff --git a/core/src/data/mitigations/MitigationDataService.spec.ts b/core/src/data/mitigations/MitigationDataService.spec.ts index 9d13749c..4535dd7e 100644 --- a/core/src/data/mitigations/MitigationDataService.spec.ts +++ b/core/src/data/mitigations/MitigationDataService.spec.ts @@ -1,5 +1,4 @@ import { randomUUID } from "crypto"; -import pg from "pg"; import Control from "../controls/Control.js"; import { DataAccessLayer } from "../dal.js"; import Model from "../models/Model.js"; @@ -12,25 +11,24 @@ import { MitigationDataService } from "./MitigationDataService.js"; describe("MitigationDataService implementation", () => { let data: MitigationDataService; let dal: DataAccessLayer; - let pool: pg.Pool; + let model: Model; beforeAll(async () => { - pool = await createPostgresPool(); - data = new MitigationDataService(pool); + const pool = await createPostgresPool(); dal = new DataAccessLayer(pool); - await _deleteAllTheThings(pool); + data = new MitigationDataService(dal); }); beforeEach(async () => { - await _deleteAllTheThings(pool); + await _deleteAllTheThings(dal); model = new Model("some-system-id", "some-version", "root"); model.data = { components: [], dataFlows: [] }; model.id = await dal.modelService.create(model); }); afterAll(async () => { - await pool.end(); + await dal.pool.end(); }); describe("getById mitigation", () => { diff --git a/core/src/data/mitigations/MitigationDataService.ts b/core/src/data/mitigations/MitigationDataService.ts index 690e9774..6c81931b 100644 --- a/core/src/data/mitigations/MitigationDataService.ts +++ b/core/src/data/mitigations/MitigationDataService.ts @@ -1,6 +1,7 @@ -import pg from "pg"; -import { EventEmitter } from "node:events"; import log4js from "log4js"; +import { EventEmitter } from "node:events"; +import { DataAccessLayer } from "../dal.js"; +import { GramConnectionPool } from "../postgres.js"; import Mitigation from "./Mitigation.js"; function convertToMitigation(row: any) { @@ -15,12 +16,12 @@ function convertToMitigation(row: any) { } export class MitigationDataService extends EventEmitter { - constructor(pool: pg.Pool) { + constructor(private dal: DataAccessLayer) { super(); - this.pool = pool; + this.pool = dal.pool; this.log = log4js.getLogger("MitigationDataService"); } - private pool: pg.Pool; + private pool: GramConnectionPool; log: any; /** @@ -45,26 +46,17 @@ export class MitigationDataService extends EventEmitter { WHERE id = $1::uuid `; - const client = await this.pool.connect(); - try { - await client.query("BEGIN"); + await this.pool.runTransaction(async (client) => { await client.query(query, [threatId, controlId, createdBy]); const res_threats = await client.query(queryThreats, [threatId]); - await client.query("COMMIT"); this.emit("updated-for", { modelId: res_threats.rows[0].model_id, componentId: res_threats.rows[0].component_id, }); - return { threatId, controlId }; - } catch (e) { - await client.query("ROLLBACK"); - this.log.error("Failed to create mitigation", e); - } finally { - client.release(); - } + }); - return false; + return { threatId, controlId }; } /** diff --git a/core/src/data/models/ModelDataService.spec.ts b/core/src/data/models/ModelDataService.spec.ts index 19555aec..cb48431c 100644 --- a/core/src/data/models/ModelDataService.spec.ts +++ b/core/src/data/models/ModelDataService.spec.ts @@ -17,21 +17,19 @@ import { sampleUser } from "../../test-util/sampleUser.js"; describe("ModelDataService implementation", () => { let data: ModelDataService; let dal: DataAccessLayer; - let pool: pg.Pool; beforeAll(async () => { - pool = await createPostgresPool(); + const pool = await createPostgresPool(); dal = new DataAccessLayer(pool); - data = new ModelDataService(pool, dal); - await _deleteAllTheThings(pool); + data = new ModelDataService(dal); }); afterEach(async () => { - await _deleteAllTheThings(pool); + await _deleteAllTheThings(dal); }); afterAll(async () => { - await pool.end(); + await dal.pool.end(); }); describe("getById model", () => { @@ -461,7 +459,6 @@ describe("ModelDataService implementation", () => { const mitigationsCopy = await dal.mitigationService.list( resModelCopy!.id! ); - console.log(mitigationsCopy); expect(mitigationsCopy.length).toEqual(3); diff --git a/core/src/data/models/ModelDataService.ts b/core/src/data/models/ModelDataService.ts index 785ad7e1..420ef8e6 100644 --- a/core/src/data/models/ModelDataService.ts +++ b/core/src/data/models/ModelDataService.ts @@ -5,10 +5,10 @@ */ import { randomUUID } from "crypto"; -import { EventEmitter } from "node:events"; -import pg from "pg"; import log4js from "log4js"; +import { EventEmitter } from "node:events"; import { DataAccessLayer } from "../dal.js"; +import { GramConnectionPool } from "../postgres.js"; import Model, { ModelData } from "./Model.js"; function convertToModel(row: any) { @@ -37,13 +37,13 @@ export interface ModelListOptions { } export class ModelDataService extends EventEmitter { - constructor(pool: pg.Pool, private dal: DataAccessLayer) { + constructor(private dal: DataAccessLayer) { super(); - this.pool = pool; + this.pool = dal.pool; this.log = log4js.getLogger("ModelDataService"); } - private pool: pg.Pool; + private pool: GramConnectionPool; log: any; /** @@ -291,10 +291,7 @@ export class ModelDataService extends EventEmitter { AND deleted_at IS NULL; `; - const client = await this.pool.connect(); - try { - await client.query("BEGIN"); - + await this.pool.runTransaction(async (client) => { for (const threat of threats) { uuid.set(threat.id!, randomUUID()); await client.query(queryThreats, [ @@ -331,14 +328,10 @@ export class ModelDataService extends EventEmitter { mitigation.controlId, ]); } - await client.query("COMMIT"); - this.emit("updated-for", { modelId: uuid.get(srcModel.id!) }); - } catch (e) { - await client.query("ROLLBACK"); - this.log.error("Failed to copy model", e); - } finally { - client.release(); - } + }); + + this.emit("updated-for", { modelId: uuid.get(srcModel.id!) }); + return uuid.get(srcModel.id!) as string; } @@ -375,10 +368,7 @@ export class ModelDataService extends EventEmitter { RETURNING id `; let success = false; - const client = await this.pool.connect(); - try { - await client.query("BEGIN"); - + const res = await this.pool.runTransaction(async (client) => { const threats = await client.query(queryThreats, [id]); const controls = await client.query(queryControls, [id]); @@ -388,17 +378,14 @@ export class ModelDataService extends EventEmitter { await client.query(queryMitigations, [threatIds, controlIds]); const res = await client.query(query, [id]); - await client.query("COMMIT"); - if (res.rowCount > 0) { - this.emit("updated-for", { modelId: id }); - success = true; - } - } catch (e) { - await client.query("ROLLBACK"); - this.log.error("Failed to delete model", e); - } finally { - client.release(); + return res; + }); + + if (res.rowCount > 0) { + this.emit("updated-for", { modelId: id }); + success = true; } + return success; } @@ -481,19 +468,15 @@ export class ModelDataService extends EventEmitter { ); `; - const client = await this.pool.connect(); - let insertRes = null; - try { - await client.query("BEGIN"); - insertRes = await client.query(insertQuery, [userId, modelId, action]); + const insertRes = await this.pool.runTransaction(async (client) => { + const insertRes = await client.query(insertQuery, [ + userId, + modelId, + action, + ]); await client.query(deleteQuery, [userId]); - await client.query("COMMIT"); - } catch (e) { - await client.query("ROLLBACK"); - this.log.error("Failed to log action", e); - } finally { - client.release(); - } + return insertRes; + }); return insertRes?.rows[0]?.id; } diff --git a/core/src/data/notifications/NotificationDataService.spec.ts b/core/src/data/notifications/NotificationDataService.spec.ts index c58dccd1..e08cae9f 100644 --- a/core/src/data/notifications/NotificationDataService.spec.ts +++ b/core/src/data/notifications/NotificationDataService.spec.ts @@ -29,23 +29,22 @@ const sampleNotification = new PlaintextHandlebarsNotificationTemplate( ); describe("NotificationDataService implementation", () => { - let pool: pg.Pool; let dal: DataAccessLayer; let data: NotificationDataService; beforeAll(async () => { - pool = await createPostgresPool(); + const pool = await createPostgresPool(); dal = new DataAccessLayer(pool); dal.templateHandler.register(sampleNotification); - data = new NotificationDataService(pool, dal); + data = new NotificationDataService(dal); }); beforeEach(async () => { - await _deleteAllTheThings(pool); + await _deleteAllTheThings(dal); }); afterAll(async () => { - await pool.end(); + await dal.pool.end(); }); describe("queue", () => { diff --git a/core/src/data/notifications/NotificationDataService.ts b/core/src/data/notifications/NotificationDataService.ts index 36f2f85e..54d4af8a 100644 --- a/core/src/data/notifications/NotificationDataService.ts +++ b/core/src/data/notifications/NotificationDataService.ts @@ -1,7 +1,7 @@ -import { Notification, NotificationStatus } from "./Notification.js"; -import pg from "pg"; import log4js from "log4js"; import { DataAccessLayer } from "../dal.js"; +import { GramConnectionPool } from "../postgres.js"; +import { Notification, NotificationStatus } from "./Notification.js"; import { NotificationInput } from "./NotificationInput.js"; function convertToNotification(row: any) { @@ -16,7 +16,11 @@ function convertToNotification(row: any) { } export class NotificationDataService { - constructor(private pool: pg.Pool, private dal: DataAccessLayer) {} + constructor(private dal: DataAccessLayer) { + this.pool = dal.pool; + } + + private pool: GramConnectionPool; log = log4js.getLogger("NotificationDataService"); diff --git a/core/src/data/postgres.spec.ts b/core/src/data/postgres.spec.ts index 1f0bc726..19eec770 100644 --- a/core/src/data/postgres.spec.ts +++ b/core/src/data/postgres.spec.ts @@ -16,7 +16,7 @@ describe("postgres pool", () => { try { await pool.query("invalid sql :)"); } catch (err) { - console.log(err); + // console.log(err); errd = true; } expect(errd).toBeTruthy(); @@ -59,7 +59,7 @@ describe("postgres pool", () => { try { await pool.runTransaction(async (client) => { try { - console.log("during transaction"); + // console.log("during transaction"); await client.query("SELECT pg_sleep(5)"); } catch (err) { // The error will be handled here first, and then the transaction will continue. @@ -72,7 +72,7 @@ describe("postgres pool", () => { // } catch (err) { // log stuff, think problem is fixed } // client.query("next part of transaction") -> error now because client is still broken. // } - console.log("transaction err", err); + // console.log("transaction err", err); expect(err).toBeTruthy(); } }); diff --git a/core/src/data/postgres.ts b/core/src/data/postgres.ts index cb29a7d2..9b21ba34 100644 --- a/core/src/data/postgres.ts +++ b/core/src/data/postgres.ts @@ -23,13 +23,18 @@ export class GramConnectionPool { } async query(query: string, ...args: any[]) { - log.info(args); + log.debug(query); return this._pool.query(query, ...args); } async runTransaction(transaction: Transaction): Promise { const client = await this._pool.connect(); + client.on("error", (err) => { + // This *should* catch weird timeout/disconnects that may happen + log.error("Transaction error", err); + }); + try { // Do transaction stuff await client.query("BEGIN"); @@ -99,6 +104,7 @@ export async function createPostgresPool(passedOpts?: pg.PoolConfig) { process.env.NODE_ENV && ["test"].includes(process.env.NODE_ENV) ? 0 : 5000, + idleTimeoutMillis: 1000, }; defaultOpts.host = await config.postgres.host.getValue(); @@ -135,9 +141,6 @@ export async function createPostgresPool(passedOpts?: pg.PoolConfig) { log.error("Pool error", err); }); - // TODO: will refactor DAL and more to use this wrapper class. - // return new GramConnectionPool(pool); - // TODO: figure out metrics for multiple pools... initPostgresMetrics(pool); diff --git a/core/src/data/reports/ReportDataService.spec.ts b/core/src/data/reports/ReportDataService.spec.ts new file mode 100644 index 00000000..9755f176 --- /dev/null +++ b/core/src/data/reports/ReportDataService.spec.ts @@ -0,0 +1,25 @@ +import { DataAccessLayer } from "../dal.js"; +import { createPostgresPool } from "../postgres.js"; +import { _deleteAllTheThings } from "../utils.js"; + +describe("ReportDataService implementation", () => { + let dal: DataAccessLayer; + + beforeAll(async () => { + const pool = await createPostgresPool(); + dal = new DataAccessLayer(pool); + }); + + afterAll(async () => { + await _deleteAllTheThings(dal); + await dal.pool.end(); + }); + + describe("listSystemCompliance", () => { + it("should not crash", async () => { + const report = await dal.reportService.listSystemCompliance({}); + + expect(report.TotalSystems).toBe("0"); + }); + }); +}); diff --git a/core/src/data/reports/ReportDataService.ts b/core/src/data/reports/ReportDataService.ts index 79b93021..eefda8ef 100644 --- a/core/src/data/reports/ReportDataService.ts +++ b/core/src/data/reports/ReportDataService.ts @@ -4,6 +4,7 @@ import { linkToModel } from "../../util/links.js"; import { DataAccessLayer } from "../dal.js"; import { SystemPropertyValue } from "../system-property/types.js"; import { RequestContext } from "../providers/RequestContext.js"; +import { GramConnectionPool } from "../postgres.js"; interface SystemCompliance { SystemID: string; @@ -27,7 +28,11 @@ interface SystemComplianceReport { } export class ReportDataService { - constructor(private pool: pg.Pool, private dal: DataAccessLayer) {} + constructor(private dal: DataAccessLayer) { + this.pool = dal.pool; + } + + private pool: GramConnectionPool; log = log4js.getLogger("ReportDataService"); diff --git a/core/src/data/reviews/ReviewDataService.spec.ts b/core/src/data/reviews/ReviewDataService.spec.ts index da18cee1..eab142e9 100644 --- a/core/src/data/reviews/ReviewDataService.spec.ts +++ b/core/src/data/reviews/ReviewDataService.spec.ts @@ -1,33 +1,31 @@ import { jest } from "@jest/globals"; -import pg from "pg"; import { randomUUID } from "crypto"; +import { SpiedFunction } from "jest-mock"; +import { createSampleModel } from "../../test-util/model.js"; +import { testReviewerProvider } from "../../test-util/sampleReviewer.js"; import { DataAccessLayer } from "../dal.js"; import { createPostgresPool } from "../postgres.js"; import { _deleteAllTheThings } from "../utils.js"; import { Review, ReviewStatus } from "./Review.js"; import { ReviewDataService } from "./ReviewDataService.js"; -import { createSampleModel } from "../../test-util/model.js"; -import { testReviewerProvider } from "../../test-util/sampleReviewer.js"; -import { SpiedFunction, SpyInstance } from "jest-mock"; describe("ReviewDataService implementation", () => { - let pool: pg.Pool; let dal: DataAccessLayer; let data: ReviewDataService; let modelId: string; let notificationQueue: SpiedFunction; beforeAll(async () => { - pool = await createPostgresPool(); + const pool = await createPostgresPool(); dal = new DataAccessLayer(pool); - data = new ReviewDataService(pool, dal); + data = new ReviewDataService(dal); notificationQueue = jest.spyOn(dal.notificationService, "queue"); await _deleteAllTheThings(pool); dal.reviewerHandler.setReviewerProvider(testReviewerProvider); }); beforeEach(async () => { - await _deleteAllTheThings(pool); + await _deleteAllTheThings(dal); /** Set up test model needed for review **/ modelId = await createSampleModel(dal); @@ -39,7 +37,7 @@ describe("ReviewDataService implementation", () => { afterAll(async () => { notificationQueue.mockRestore(); - await pool.end(); + await dal.pool.end(); }); describe("getByModelId", () => { diff --git a/core/src/data/reviews/ReviewDataService.ts b/core/src/data/reviews/ReviewDataService.ts index 66a17c97..8ff281dc 100644 --- a/core/src/data/reviews/ReviewDataService.ts +++ b/core/src/data/reviews/ReviewDataService.ts @@ -9,6 +9,7 @@ import { } from "../system-property/types.js"; import { Review, ReviewStatus } from "./Review.js"; import { ReviewSystemCompliance } from "./ReviewSystemCompliance.js"; +import { GramConnectionPool } from "../postgres.js"; export function convertToReview(row: any): Review { const review = new Review(row.model_id, row.requested_by, row.status); @@ -52,10 +53,13 @@ interface ReviewListResult { } export class ReviewDataService extends EventEmitter { - constructor(private pool: pg.Pool, private dal: DataAccessLayer) { + constructor(private dal: DataAccessLayer) { super(); + this.pool = dal.pool; } + private pool: GramConnectionPool; + log = log4js.getLogger("ReviewDataService"); /** diff --git a/core/src/data/suggestions/SuggestionDataService.spec.ts b/core/src/data/suggestions/SuggestionDataService.spec.ts index 454e0edf..25ff5660 100644 --- a/core/src/data/suggestions/SuggestionDataService.spec.ts +++ b/core/src/data/suggestions/SuggestionDataService.spec.ts @@ -21,28 +21,21 @@ import { } from "./Suggestion.js"; describe("SuggestionDataService implementation", () => { - let pool: pg.Pool; let dal: DataAccessLayer; - let modelId: string; beforeAll(async () => { - pool = await createPostgresPool(); + const pool = await createPostgresPool(); dal = new DataAccessLayer(pool); }); - beforeEach(async () => { - await _deleteAllTheThings(pool); - - /** Set up test model needed for review **/ - modelId = await createSampleModel(dal); - }); - afterAll(async () => { - await pool.end(); + await _deleteAllTheThings(dal); + await dal.pool.end(); }); describe("bulkInsert", () => { it("should be able to insert empty threats and controls", async () => { + const modelId = await createSampleModel(dal); const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [], @@ -52,6 +45,7 @@ describe("SuggestionDataService implementation", () => { }); it("should be able to insert multiple threats and controls", async () => { + const modelId = await createSampleModel(dal); const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [...new Array(50)].map(genSuggestedControl), @@ -68,6 +62,7 @@ describe("SuggestionDataService implementation", () => { }); it("should remove unused suggestions that are no longer included in the batch, but keep ones that have been added/rejected", async () => { + const modelId = await createSampleModel(dal); const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [...new Array(50)].map(genSuggestedControl), @@ -118,6 +113,7 @@ describe("SuggestionDataService implementation", () => { }); it("should not have different sources interfering with each others' batches", async () => { + const modelId = await createSampleModel(dal); const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [...new Array(50)].map(genSuggestedControl), @@ -140,6 +136,7 @@ describe("SuggestionDataService implementation", () => { }); it("should not have different models interfering with each others' batches", async () => { + const modelId = await createSampleModel(dal); const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [...new Array(3)].map(genSuggestedControl), @@ -167,6 +164,7 @@ describe("SuggestionDataService implementation", () => { }); it("should insert control suggestions with empty mitigations", async () => { + const modelId = await createSampleModel(dal); const suggestionsAfter: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [genSuggestedControl()], @@ -187,6 +185,7 @@ describe("SuggestionDataService implementation", () => { }); it("should insert control suggestions with list of mitigations", async () => { + const modelId = await createSampleModel(dal); const suggestThreats = [...Array(10)].map(genSuggestedThreat); const partialThreatIds = suggestThreats.map((t: any) => t.id.val.split("/").slice(1).join("/") @@ -219,6 +218,7 @@ describe("SuggestionDataService implementation", () => { describe("listControlSuggestions", () => { it("should return an empty list if no suggestions", async () => { + const modelId = await createSampleModel(dal); const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [], @@ -231,6 +231,7 @@ describe("SuggestionDataService implementation", () => { }); it("should return a list", async () => { + const modelId = await createSampleModel(dal); const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [...new Array(50)].map(() => genSuggestedControl()), @@ -246,6 +247,7 @@ describe("SuggestionDataService implementation", () => { }); it("should return the correct list", async () => { + const modelId = await createSampleModel(dal); const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [...new Array(50)].map(() => genSuggestedControl()), @@ -266,6 +268,7 @@ describe("SuggestionDataService implementation", () => { describe("listThreatSuggestions", () => { it("should return an empty list if no suggestions", async () => { + const modelId = await createSampleModel(dal); const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [], @@ -278,6 +281,7 @@ describe("SuggestionDataService implementation", () => { }); it("should return a list", async () => { + const modelId = await createSampleModel(dal); const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [], @@ -295,6 +299,7 @@ describe("SuggestionDataService implementation", () => { describe("acceptSuggestion", () => { it("should return false if suggestion not found", async () => { + const modelId = await createSampleModel(dal); const res = await dal.suggestionService.acceptSuggestion( modelId, new SuggestionID(`${randomUUID()}/test-source/threat/test-1-23`), @@ -302,15 +307,19 @@ describe("SuggestionDataService implementation", () => { ); expect(res).toBe(false); }); + it("should return true if suggestion is control or threat", async () => { + const modelId = await createSampleModel(dal); const suggestThreat = genSuggestedThreat(); const suggestControl = genSuggestedControl({ + componentId: randomUUID(), mitigates: [ { partialThreatId: suggestThreat.id.val.split("/").slice(1).join("/"), }, ], }); + const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [suggestControl], @@ -320,18 +329,19 @@ describe("SuggestionDataService implementation", () => { let res = await dal.suggestionService.acceptSuggestion( modelId, - suggestions.controls[0].id, + suggestControl.id, "someuser" ); expect(res).toBe(true); let suggestion = (await dal.suggestionService.getById( modelId, - suggestions.controls[0].id + suggestControl.id )) as SuggestedControl | SuggestedThreat; - expect(suggestion.status).toBe(SuggestionStatus.Accepted); + + expect(suggestion.status).toEqual(SuggestionStatus.Accepted); const control = await dal.suggestionService._getLinkedThreatOrControl( modelId, - suggestions.controls[0].id + suggestControl.id ); expect(control.title).toEqual(suggestion.title); expect(control.description).toEqual(suggestion.description); @@ -339,18 +349,20 @@ describe("SuggestionDataService implementation", () => { // At this point, no mitigation should exist. res = await dal.suggestionService.acceptSuggestion( modelId, - suggestions.threats[0].id, + suggestThreat.id, "someuser" ); expect(res).toBe(true); suggestion = (await dal.suggestionService.getById( modelId, - suggestions.threats[0].id + suggestThreat.id )) as SuggestedThreat; - expect(suggestion.status).toBe(SuggestionStatus.Accepted); + + expect(suggestion.status).toEqual(SuggestionStatus.Accepted); }); it("should create mitigation(s) if relevant threat exists", async () => { + const modelId = await createSampleModel(dal); const suggestThreats = [...Array(5)].map(genSuggestedThreat); const suggestControl = genSuggestedControl({ mitigates: suggestThreats.map((suggestThreat) => ({ @@ -404,6 +416,7 @@ describe("SuggestionDataService implementation", () => { }); it("should NOT create mitigation if relevant threat does NOT exists", async () => { + const modelId = await createSampleModel(dal); const suggestThreat = genSuggestedThreat(); const suggestControl = genSuggestedControl({ mitigates: [ @@ -453,6 +466,7 @@ describe("SuggestionDataService implementation", () => { }); it("should NOT list deleted threats from partialId", async () => { + const modelId = await createSampleModel(dal); const suggestThreat = genSuggestedThreat(); const suggestControl = genSuggestedControl({ mitigates: [ @@ -489,6 +503,7 @@ describe("SuggestionDataService implementation", () => { describe("setSuggestionStatus", () => { it("should return false if suggestion not found", async () => { + const modelId = await createSampleModel(dal); const res = await dal.suggestionService.setSuggestionStatus( modelId, new SuggestionID(`${randomUUID()}/test-source/threat/test-1-23`), @@ -498,6 +513,7 @@ describe("SuggestionDataService implementation", () => { }); it("should be able to set all statuses", async () => { + const modelId = await createSampleModel(dal); const suggestions: EngineSuggestedResult = { sourceSlugToClear: "test", controls: [genSuggestedControl()], @@ -505,11 +521,11 @@ describe("SuggestionDataService implementation", () => { }; await dal.suggestionService.bulkInsert(modelId, suggestions); - [ + for (const status of [ SuggestionStatus.Accepted, SuggestionStatus.New, SuggestionStatus.Rejected, - ].forEach(async (status) => { + ]) { let res = await dal.suggestionService.setSuggestionStatus( modelId, suggestions.threats[0].id, @@ -531,7 +547,7 @@ describe("SuggestionDataService implementation", () => { modelId ); expect(threats[0].status).toBe(status); - }); + } }); }); }); diff --git a/core/src/data/suggestions/SuggestionDataService.ts b/core/src/data/suggestions/SuggestionDataService.ts index 819c4e89..53a174d6 100644 --- a/core/src/data/suggestions/SuggestionDataService.ts +++ b/core/src/data/suggestions/SuggestionDataService.ts @@ -17,6 +17,7 @@ import { SuggestedThreat, SuggestionStatus, } from "./Suggestion.js"; +import { GramConnectionPool } from "../postgres.js"; function convertToSuggestionControl(row: any) { const control = new SuggestedControl( @@ -50,10 +51,13 @@ function convertToSuggestionThreat(row: any) { const log = log4js.getLogger("SuggestionDataService"); export class SuggestionDataService extends EventEmitter { - constructor(private pool: Pool, private dal: DataAccessLayer) { + constructor(private dal: DataAccessLayer) { super(); + this.pool = dal.pool; } + private pool: GramConnectionPool; + /** * Copy suggestions from one model to anothger * @param fromModelId @@ -128,11 +132,7 @@ export class SuggestionDataService extends EventEmitter { DELETE FROM suggested_controls WHERE source = $1::varchar and model_id = $2::uuid and status = 'new'; `; - const client = await this.pool.connect(); - - try { - await client.query("BEGIN"); - + await this.pool.runTransaction(async (client) => { // Clear previous batches from this source if (suggestions.sourceSlugToClear) { await client.query(deleteControlsQuery, [ @@ -179,19 +179,13 @@ export class SuggestionDataService extends EventEmitter { } const queries = bulkThreats.concat(bulkControls); await Promise.all(queries); - await client.query("COMMIT"); log.debug( `inserted ${bulkThreats.length} suggested threats, ${bulkControls.length} suggested controls.` ); this.emit("updated-for", { modelId, }); - } catch (e) { - await client.query("ROLLBACK"); - log.error("Failed to insert suggestions", e); - } finally { - client.release(); - } + }); } async listControlSuggestions(modelId: string) { diff --git a/core/src/data/threats/ThreatDataService.spec.ts b/core/src/data/threats/ThreatDataService.spec.ts index b12c7c86..7fe0ad77 100644 --- a/core/src/data/threats/ThreatDataService.spec.ts +++ b/core/src/data/threats/ThreatDataService.spec.ts @@ -12,25 +12,23 @@ import { ThreatDataService } from "./ThreatDataService.js"; describe("ThreatDataService implementation", () => { let data: ThreatDataService; let dal: DataAccessLayer; - let pool: pg.Pool; let model: Model; beforeAll(async () => { - pool = await createPostgresPool(); + const pool = await createPostgresPool(); dal = new DataAccessLayer(pool); data = dal.threatService; - await _deleteAllTheThings(pool); }); beforeEach(async () => { - await _deleteAllTheThings(pool); + await _deleteAllTheThings(dal); model = new Model("some-system-id", "some-version", "root"); model.data = { components: [], dataFlows: [] }; model.id = await dal.modelService.create(model); }); afterAll(async () => { - await pool.end(); + await dal.pool.end(); }); describe("getById threat", () => { diff --git a/core/src/data/threats/ThreatDataService.ts b/core/src/data/threats/ThreatDataService.ts index a93e811c..1adc2b37 100644 --- a/core/src/data/threats/ThreatDataService.ts +++ b/core/src/data/threats/ThreatDataService.ts @@ -5,10 +5,10 @@ */ import { EventEmitter } from "events"; -import pg from "pg"; import log4js from "log4js"; import { SuggestionID } from "../../suggestions/models.js"; import { DataAccessLayer } from "../dal.js"; +import { GramConnectionPool } from "../postgres.js"; import { SuggestionStatus } from "../suggestions/Suggestion.js"; import Threat, { ThreatSeverity } from "./Threat.js"; @@ -30,15 +30,13 @@ export function convertToThreat(row: any): Threat { } export class ThreatDataService extends EventEmitter { - constructor(pool: pg.Pool, dal: DataAccessLayer) { + constructor(private dal: DataAccessLayer) { super(); - this.pool = pool; - this.dal = dal; + this.pool = dal.pool; this.log = log4js.getLogger("ThreatDataService"); } - private pool: pg.Pool; - private dal: DataAccessLayer; + private pool: GramConnectionPool; log: any; /** @@ -262,41 +260,39 @@ export class ThreatDataService extends EventEmitter { WHERE m.threat_id = t.id AND t.model_id = $1::uuid and t.id ${filter} `; - const client = await this.pool.connect(); - let result = false; - try { - await client.query("BEGIN"); - const res = await client.query(query, [modelId, ...ids]); - result = res.rowCount > 0; - - if (result) { - const suggestionIds = res.rows - .filter((v: any) => v.suggestion_id) - .map((v: any) => new SuggestionID(v.suggestion_id)); - - // This runs in a different client and could be problematic. - const promises = suggestionIds.map((id) => - this.dal.suggestionService.setSuggestionStatus( - res.rows[0].model_id, - id, - SuggestionStatus.New - ) - ); - await Promise.all(promises); - - await client.query(queryMitigations, [modelId, ...ids]); - this.emit("deleted-for", { - modelId: res.rows[0].model_id, - componentId: res.rows[0].component_id, - }); + const [result, suggestionIds] = await this.pool.runTransaction( + async (client) => { + let suggestionIds: SuggestionID[] = []; + + const res = await client.query(query, [modelId, ...ids]); + const result = res.rowCount > 0; + + if (result) { + suggestionIds = res.rows + .filter((v: any) => v.suggestion_id) + .map((v: any) => new SuggestionID(v.suggestion_id)); + + await client.query(queryMitigations, [modelId, ...ids]); + this.emit("deleted-for", { + modelId: res.rows[0].model_id, + componentId: res.rows[0].component_id, + }); + } + + return [result, suggestionIds]; } - await client.query("COMMIT"); - } catch (e) { - this.log.error("Failed to delete threat", e); - await client.query("ROLLBACK"); - } finally { - client.release(); - } + ); + + // This runs in a different client and could be problematic. + const promises = suggestionIds.map((id) => + this.dal.suggestionService.setSuggestionStatus( + modelId, + id, + SuggestionStatus.New + ) + ); + + await Promise.all(promises); return result; } diff --git a/core/src/data/utils.ts b/core/src/data/utils.ts index cedd261e..75340d65 100644 --- a/core/src/data/utils.ts +++ b/core/src/data/utils.ts @@ -1,13 +1,23 @@ import pg from "pg"; import log4js from "log4js"; +import { GramConnectionPool } from "./postgres.js"; +import { DataAccessLayer } from "./dal.js"; const log = log4js.getLogger("UtilsDataService"); -export async function _deleteAllTheThings(pool: pg.Pool) { +export async function _deleteAllTheThings( + pool: pg.Pool | GramConnectionPool | DataAccessLayer +) { if (process.env.NODE_ENV !== "test") { log.warn("Attempted to _deleteAllTheThings in a non-test environment."); return; } + if (pool instanceof DataAccessLayer) { + pool = pool.pool._pool; // hehehe + } + if (pool instanceof GramConnectionPool) { + pool = pool._pool; + } await pool.query("TRUNCATE TABLE mitigations CASCADE"); await pool.query("TRUNCATE TABLE controls CASCADE"); await pool.query("TRUNCATE TABLE threats CASCADE"); diff --git a/core/src/notifications/sender.spec.ts b/core/src/notifications/sender.spec.ts index e46b8d28..57038104 100644 --- a/core/src/notifications/sender.spec.ts +++ b/core/src/notifications/sender.spec.ts @@ -13,7 +13,6 @@ const mockedSend = jest.fn(); describe("notification sender", () => { let notificationService: NotificationDataService; let templateHandler: TemplateHandler; - let pool: any; let dal: DataAccessLayer; const sampleNotification: NotificationInput = { @@ -22,7 +21,7 @@ describe("notification sender", () => { }; beforeAll(async () => { - pool = await createPostgresPool(); + const pool = await createPostgresPool(); dal = new DataAccessLayer(pool); dal.templateHandler.register( new PlaintextHandlebarsNotificationTemplate( @@ -36,7 +35,11 @@ describe("notification sender", () => { }) ) ); - notificationService = new NotificationDataService(pool, dal); + notificationService = new NotificationDataService(dal); + }); + + afterAll(async () => { + await dal.pool.end(); }); describe("notificationSender", () => { diff --git a/core/src/suggestions/engine.spec.ts b/core/src/suggestions/engine.spec.ts index bde86ac2..d8e208da 100644 --- a/core/src/suggestions/engine.spec.ts +++ b/core/src/suggestions/engine.spec.ts @@ -46,13 +46,12 @@ const ErroringSuggestionSource: SuggestionSource = { }; describe("SuggestionEngine", () => { - let pool: pg.Pool; let dal: DataAccessLayer; let modelId: string; let engine: SuggestionEngine; beforeAll(async () => { - pool = await createPostgresPool(); + const pool = await createPostgresPool(); dal = new DataAccessLayer(pool); engine = new SuggestionEngine(dal, true); }); @@ -63,7 +62,7 @@ describe("SuggestionEngine", () => { }); afterAll(async () => { - await pool.end(); + await dal.pool.end(); }); it("should handle suggestionsource errors gracefully", async () => { diff --git a/package-lock.json b/package-lock.json index 3c6afca6..1141c26f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -176,7 +176,7 @@ "handlebars": "^4.7.7", "jsonwebtoken": "^9.0.0", "log4js": "^6.6.1", - "pg": "^8.11.1", + "pg": "^8.11.3", "postgres-migrations": "^5.3.0", "prom-client": "^14.0.1", "tslib": "^2.3.1" @@ -24529,13 +24529,13 @@ "dev": true }, "node_modules/pg": { - "version": "8.11.1", - "resolved": "https://registry.npmjs.org/pg/-/pg-8.11.1.tgz", - "integrity": "sha512-utdq2obft07MxaDg0zBJI+l/M3mBRfIpEN3iSemsz0G5F2/VXx+XzqF4oxrbIZXQxt2AZzIUzyVg/YM6xOP/WQ==", + "version": "8.11.3", + "resolved": "https://registry.npmjs.org/pg/-/pg-8.11.3.tgz", + "integrity": "sha512-+9iuvG8QfaaUrrph+kpF24cXkH1YOOUeArRNYIxq1viYHZagBxrTno7cecY1Fa44tJeZvaoG+Djpkc3JwehN5g==", "dependencies": { "buffer-writer": "2.0.0", "packet-reader": "1.0.0", - "pg-connection-string": "^2.6.1", + "pg-connection-string": "^2.6.2", "pg-pool": "^3.6.1", "pg-protocol": "^1.6.0", "pg-types": "^2.1.0", @@ -24563,9 +24563,9 @@ "optional": true }, "node_modules/pg-connection-string": { - "version": "2.6.1", - "resolved": "https://registry.npmjs.org/pg-connection-string/-/pg-connection-string-2.6.1.tgz", - "integrity": "sha512-w6ZzNu6oMmIzEAYVw+RLK0+nqHPt8K3ZnknKi+g48Ak2pr3dtljJW3o+D/n2zzCG07Zoe9VOX3aiKpj+BN0pjg==" + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/pg-connection-string/-/pg-connection-string-2.6.2.tgz", + "integrity": "sha512-ch6OwaeaPYcova4kKZ15sbJ2hKb/VP48ZD2gE7i1J+L4MspCtBMAx8nMgz7bksc7IojCIIWuEhHibSMFH8m8oA==" }, "node_modules/pg-int8": { "version": "1.0.1",