Skip to content

Commit

Permalink
Merge branch 'main' into betavae-1
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Feb 8, 2024
2 parents dd847d5 + 4dbde41 commit 644875d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion numalogic/udfs/trainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:

# Construct feature array
x_train, nan_counter, inf_counter = self.get_feature_arr(
df, payload.metrics, max_value_map=_conf.numalogic_conf.trainer.max_value_map
df, _conf.metrics, max_value_map=_conf.numalogic_conf.trainer.max_value_map
)
_add_summary(
summary=NAN_SUMMARY,
Expand Down
9 changes: 8 additions & 1 deletion tests/udfs/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def test_trainer_01(self):
conf={"seq_len": 12, "n_features": 2},
),
preprocess=[ModelInfo(name="LogTransformer", stateful=True, conf={})],
trainer=TrainerConf(pltrainer_conf=LightningTrainerConf(max_epochs=1)),
trainer=TrainerConf(
pltrainer_conf=LightningTrainerConf(accelerator="cpu", max_epochs=1)
),
),
)
}
Expand Down Expand Up @@ -122,6 +124,7 @@ def test_trainer_03(self):
ml_pipelines={
"pipeline1": MLPipelineConf(
pipeline_id="pipeline1",
metrics=["failed", "degraded"],
numalogic_conf=NumalogicConf(
model=ModelInfo(
name="VanillaAE", conf={"seq_len": 12, "n_features": 2}
Expand Down Expand Up @@ -154,6 +157,7 @@ def test_trainer_do_train(self):
ml_pipelines={
"pipeline1": MLPipelineConf(
pipeline_id="pipeline1",
metrics=["failed", "degraded"],
numalogic_conf=NumalogicConf(
model=ModelInfo(
name="VanillaAE", conf={"seq_len": 12, "n_features": 2}
Expand Down Expand Up @@ -197,6 +201,7 @@ def test_trainer_do_not_train_1(self):
ml_pipelines={
"pipeline1": MLPipelineConf(
pipeline_id="pipeline1",
metrics=["failed", "degraded"],
numalogic_conf=NumalogicConf(
model=ModelInfo(
name="VanillaAE", conf={"seq_len": 12, "n_features": 2}
Expand Down Expand Up @@ -238,6 +243,7 @@ def test_trainer_do_not_train_2(self):
ml_pipelines={
"pipeline1": MLPipelineConf(
pipeline_id="pipeline1",
metrics=["failed", "degraded"],
numalogic_conf=NumalogicConf(
model=ModelInfo(
name="VanillaAE", conf={"seq_len": 12, "n_features": 2}
Expand Down Expand Up @@ -402,6 +408,7 @@ def test_trainer_datafetcher_err_and_train(self):
ml_pipelines={
"pipeline1": MLPipelineConf(
pipeline_id="pipeline1",
metrics=["failed", "degraded"],
numalogic_conf=NumalogicConf(
model=ModelInfo(
name="VanillaAE", conf={"seq_len": 12, "n_features": 2}
Expand Down

0 comments on commit 644875d

Please sign in to comment.