Skip to content

Commit

Permalink
Fix: Get metrics from _conf.metrics on Trainer to avoid issue with Fl…
Browse files Browse the repository at this point in the history
…attening the matrix (#346)

---------

Signed-off-by: ssrigiri1 <[email protected]>
Co-authored-by: Avik Basu <[email protected]>
  • Loading branch information
shashank10456 and ab93 authored Feb 8, 2024
1 parent 2d84876 commit 4dbde41
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion numalogic/udfs/trainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:

# Construct feature array
x_train, nan_counter, inf_counter = self.get_feature_arr(
df, payload.metrics, max_value_map=_conf.numalogic_conf.trainer.max_value_map
df, _conf.metrics, max_value_map=_conf.numalogic_conf.trainer.max_value_map
)
_add_summary(
summary=NAN_SUMMARY,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.6.2"
version = "0.6.3"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
9 changes: 8 additions & 1 deletion tests/udfs/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def test_trainer_01(self):
conf={"seq_len": 12, "n_features": 2},
),
preprocess=[ModelInfo(name="LogTransformer", stateful=True, conf={})],
trainer=TrainerConf(pltrainer_conf=LightningTrainerConf(max_epochs=1)),
trainer=TrainerConf(
pltrainer_conf=LightningTrainerConf(accelerator="cpu", max_epochs=1)
),
),
)
}
Expand Down Expand Up @@ -122,6 +124,7 @@ def test_trainer_03(self):
ml_pipelines={
"pipeline1": MLPipelineConf(
pipeline_id="pipeline1",
metrics=["failed", "degraded"],
numalogic_conf=NumalogicConf(
model=ModelInfo(
name="VanillaAE", conf={"seq_len": 12, "n_features": 2}
Expand Down Expand Up @@ -154,6 +157,7 @@ def test_trainer_do_train(self):
ml_pipelines={
"pipeline1": MLPipelineConf(
pipeline_id="pipeline1",
metrics=["failed", "degraded"],
numalogic_conf=NumalogicConf(
model=ModelInfo(
name="VanillaAE", conf={"seq_len": 12, "n_features": 2}
Expand Down Expand Up @@ -197,6 +201,7 @@ def test_trainer_do_not_train_1(self):
ml_pipelines={
"pipeline1": MLPipelineConf(
pipeline_id="pipeline1",
metrics=["failed", "degraded"],
numalogic_conf=NumalogicConf(
model=ModelInfo(
name="VanillaAE", conf={"seq_len": 12, "n_features": 2}
Expand Down Expand Up @@ -238,6 +243,7 @@ def test_trainer_do_not_train_2(self):
ml_pipelines={
"pipeline1": MLPipelineConf(
pipeline_id="pipeline1",
metrics=["failed", "degraded"],
numalogic_conf=NumalogicConf(
model=ModelInfo(
name="VanillaAE", conf={"seq_len": 12, "n_features": 2}
Expand Down Expand Up @@ -402,6 +408,7 @@ def test_trainer_datafetcher_err_and_train(self):
ml_pipelines={
"pipeline1": MLPipelineConf(
pipeline_id="pipeline1",
metrics=["failed", "degraded"],
numalogic_conf=NumalogicConf(
model=ModelInfo(
name="VanillaAE", conf={"seq_len": 12, "n_features": 2}
Expand Down

0 comments on commit 4dbde41

Please sign in to comment.