Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewpeng02 committed May 10, 2024
1 parent d9c11d8 commit e7ce7d6
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 119 deletions.
18 changes: 18 additions & 0 deletions dlp-terraform/ecs/sqs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,24 @@ resource "aws_sqs_queue" "training_queue" {
name = "training-queue.fifo"
fifo_queue = true
message_retention_seconds = 60*24

redrive_policy = jsonencode({
deadLetterTargetArn = aws_sqs_queue.training_queue_deadletter.arn
maxReceiveCount = 4
})
}

resource "aws_sqs_queue" "training_queue_deadletter" {
name = "training-deadletter-queue"
}

resource "aws_sqs_queue_redrive_allow_policy" "training_queue_redrive_allow_policy" {
queue_url = aws_sqs_queue.training_queue_deadletter.id

redrive_allow_policy = jsonencode({
redrivePermission = "byQueue",
sourceQueueArns = [aws_sqs_queue.training_queue.arn]
})
}

output "sqs_queue_url" {
Expand Down
127 changes: 9 additions & 118 deletions frontend/src/pages/train/[train_space_id].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import { DetailedTrainResultsData } from "@/features/Train/types/trainTypes";
import Container from "@mui/material/Container";
import Grid from "@mui/material/Grid";
import Paper from "@mui/material/Paper";
import dynamic from "next/dynamic";
import { useRouter } from "next/router";
import { Data, XAxisName, YAxisName } from "plotly.js";
import React, { useEffect } from "react";
const Plot = dynamic(() => import("react-plotly.js"), { ssr: false });

const LINE_CHART_COLORS = ["red", "blue", "green"];
import {
mapMetricToLinePlot,
mapMetricToAucRocPlot,
mapMetricToConfusionMatrixPlot,
} from "./metrics_to_charts";

const mapTrainResultsDataToCharts = (
detailedTrainResultsData: DetailedTrainResultsData
Expand All @@ -27,120 +27,11 @@ const mapTrainResultsDataToCharts = (
while (i < sortedData.length) {
const metric = sortedData[i];
if (metric.chart_type === "LINE") {
const data = [];
for (let i = 0; i < metric.time_series.length; i++) {
const time_series = metric.time_series[i];
data.push({
name: time_series.y_name,
x: time_series.x_values,
y: time_series.y_values,
type: "scatter",
mode: "markers",
marker: { color: LINE_CHART_COLORS[i], size: 10 },
});
}
charts.push(
<Plot
data={data as Data[]}
layout={{
height: 350,
width: 525,
xaxis: { title: metric.time_series[0].x_name },
// yaxis: { title: "Y axis" },
title: metric.name,
showlegend: true,
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
config={{ responsive: true }}
/>
);
charts.push(mapMetricToLinePlot(metric));
} else if (metric.chart_type === "AUC/ROC") {
charts.push(
<Plot
data={[
{
name: "baseline",
x: [0, 1],
y: [0, 1],
type: "scatter",
marker: { color: "grey" },
line: {
dash: "dash",
},
},
...(metric.values.map((x) => ({
name: `(AUC: ${x[2]})`,
x: x[0] as number[],
y: x[1] as number[],
type: "scatter",
})) as Data[]),
]}
layout={{
height: 350,
width: 525,
xaxis: { title: "False Positive Rate" },
yaxis: { title: "True Positive Rate" },
title: "AUC/ROC Curves for your Deep Learning Model",
showlegend: true,
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
config={{ responsive: true }}
/>
);
charts.push(mapMetricToAucRocPlot(metric));
} else if (metric.chart_type === "CONFUSION_MATRIX") {
charts.push(
<Plot
data={[
{
z: metric.values,
type: "heatmap",
colorscale: [
[0, "#e6f6fe"],
[1, "#003058"],
],
},
]}
layout={{
height: 525,
width: 525,
title: "Confusion Matrix (Last Epoch)",
xaxis: {
title: "Predicted",
},
yaxis: {
title: "Actual",
autorange: "reversed",
},
showlegend: true,
annotations: metric.values
.map((row, i) =>
row.map((_, j) => ({
xref: "x1" as XAxisName,
yref: "y1" as YAxisName,
x: j,
y: (i + metric.values.length - 1) % metric.values.length,
text: metric.values[
(i + metric.values.length - 1) % metric.values.length
][j].toString(),
font: {
color:
metric.values[
(i + metric.values.length - 1) % metric.values.length
][j] > 0
? "white"
: "black",
},
showarrow: false,
}))
)
.flat(),
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
/>
);
charts.push(mapMetricToConfusionMatrixPlot(metric));
} else {
throw Error("Undefined chart type received");
}
Expand All @@ -163,7 +54,7 @@ const TrainSpace = () => {
router.replace({ pathname: "/login" });
}
}, [user, router.isReady]);

if (error) {
setTimeout(() => refetch(), 3000);
}
Expand Down
135 changes: 135 additions & 0 deletions frontend/src/pages/train/metrics_to_charts.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import {
AucRocChart,
ConfusionMatrixChart,
TimeSeriesChart,
} from "@/features/Train/types/trainTypes";
import dynamic from "next/dynamic";
import { Data, XAxisName, YAxisName } from "plotly.js";
const Plot = dynamic(() => import("react-plotly.js"), { ssr: false });

const LINE_CHART_COLORS = ["red", "blue", "green"];

const mapMetricToLinePlot = (metric: TimeSeriesChart) => {
const data = [];
for (let i = 0; i < metric.time_series.length; i++) {
const time_series = metric.time_series[i];
data.push({
name: time_series.y_name,
x: time_series.x_values,
y: time_series.y_values,
type: "scatter",
mode: "markers",
marker: { color: LINE_CHART_COLORS[i], size: 10 },
});
}
return (
<Plot
data={data as Data[]}
layout={{
height: 350,
width: 525,
xaxis: { title: metric.time_series[0].x_name },
// yaxis: { title: "Y axis" },
title: metric.name,
showlegend: true,
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
config={{ responsive: true }}
/>
);
};

const mapMetricToAucRocPlot = (metric: AucRocChart) => {
return (
<Plot
data={[
{
name: "baseline",
x: [0, 1],
y: [0, 1],
type: "scatter",
marker: { color: "grey" },
line: {
dash: "dash",
},
},
...(metric.values.map((x) => ({
name: `(AUC: ${x[2]})`,
x: x[0] as number[],
y: x[1] as number[],
type: "scatter",
})) as Data[]),
]}
layout={{
height: 350,
width: 525,
xaxis: { title: "False Positive Rate" },
yaxis: { title: "True Positive Rate" },
title: "AUC/ROC Curves for your Deep Learning Model",
showlegend: true,
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
config={{ responsive: true }}
/>
);
};

const mapMetricToConfusionMatrixPlot = (metric: ConfusionMatrixChart) => {
<Plot
data={[
{
z: metric.values,
type: "heatmap",
colorscale: [
[0, "#e6f6fe"],
[1, "#003058"],
],
},
]}
layout={{
height: 525,
width: 525,
title: "Confusion Matrix (Last Epoch)",
xaxis: {
title: "Predicted",
},
yaxis: {
title: "Actual",
autorange: "reversed",
},
showlegend: true,
annotations: metric.values
.map((row, i) =>
row.map((_, j) => ({
xref: "x1" as XAxisName,
yref: "y1" as YAxisName,
x: j,
y: (i + metric.values.length - 1) % metric.values.length,
text: metric.values[
(i + metric.values.length - 1) % metric.values.length
][j].toString(),
font: {
color:
metric.values[
(i + metric.values.length - 1) % metric.values.length
][j] > 0
? "white"
: "black",
},
showarrow: false,
}))
)
.flat(),
paper_bgcolor: "rgba(0,0,0,0)",
plot_bgcolor: "rgba(0,0,0,0)",
}}
/>;
};

export {
mapMetricToLinePlot,
mapMetricToAucRocPlot,
mapMetricToConfusionMatrixPlot,
};
1 change: 1 addition & 0 deletions training/training/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DLP_EXECUTIONS_BUCKET_NAME = "dlp-executions"
3 changes: 2 additions & 1 deletion training/training/core/celery/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import boto3


from training.constants import DLP_EXECUTIONS_BUCKET_NAME
from training.core.celery.criterion import getCriterionHandler
from training.core.celery.dataset import SklearnDatasetCreator
from training.core.celery.dataset import ImageDefaultDatasetCreator
Expand All @@ -34,7 +35,7 @@ def saveDetailedTrainResultsDataToS3(
):
s3 = boto3.resource("s3")
s3.Object(
"dlp-executions", f"{detailedTrainResultsData.basic_info.trainspaceId}.json"
DLP_EXECUTIONS_BUCKET_NAME, f"{detailedTrainResultsData.basic_info.trainspaceId}.json"
).put(Body=detailedTrainResultsData.json())


Expand Down

0 comments on commit e7ce7d6

Please sign in to comment.