Skip to content

Commit

Permalink
Integration test for MetricCollection, ClasswiseWrapper and `self…
Browse files Browse the repository at this point in the history
….log_dict` method (#2720)

* integration test
* batch_idx missing
  • Loading branch information
SkafteNicki authored Sep 4, 2024
1 parent a31417c commit 1ba2c85
Showing 1 changed file with 70 additions and 2 deletions.
72 changes: 70 additions & 2 deletions tests/integrations/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@

from torchmetrics import MetricCollection
from torchmetrics.aggregation import SumMetric
from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision
from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision, MulticlassAccuracy
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from torchmetrics.utilities.prints import rank_zero_only
from torchmetrics.wrappers import MultitaskWrapper
from torchmetrics.wrappers import ClasswiseWrapper, MultitaskWrapper

from integrations.lightning.boring_model import BoringModel

Expand Down Expand Up @@ -504,3 +504,71 @@ def configure_optimizers(self):

model = model.type(torch.half)
assert model.metric.sum_value.dtype == torch.float32


def test_collection_classwise_lightning_integration(tmpdir):
"""Check the integration of ClasswiseWrapper, MetricCollection and LightningModule.
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2683
"""

class TestModel(BoringModel):
def __init__(self) -> None:
super().__init__()
self.train_metrics = MetricCollection(
{
"macro_accuracy": MulticlassAccuracy(num_classes=5, average="macro"),
"classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=5, average=None)),
},
prefix="train_",
)

self.val_metrics = MetricCollection(
{
"macro_accuracy": MulticlassAccuracy(num_classes=5, average="macro"),
"classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=5, average=None)),
},
prefix="val_",
)

def training_step(self, batch, batch_idx):
loss = self(batch).sum()

preds = torch.randint(0, 5, (100,), device=batch.device)
target = torch.randint(0, 5, (100,), device=batch.device)

self.train_metrics.update(preds, target)
batch_values = self.train_metrics.compute()
self.log_dict(batch_values, on_step=True, on_epoch=False)

return {"loss": loss}

def validation_step(self, batch, batch_idx):
preds = torch.randint(0, 5, (100,), device=batch.device)
target = torch.randint(0, 5, (100,), device=batch.device)

self.val_metrics.update(preds, target)

def on_validation_epoch_end(self):
self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True)

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
)
trainer.fit(model)

logged = trainer.logged_metrics

# check that all metrics are logged
assert "train_macro_accuracy" in logged
assert "val_macro_accuracy" in logged
for i in range(5):
assert f"train_multiclassaccuracy_{i}" in logged
assert f"val_multiclassaccuracy_{i}" in logged

0 comments on commit 1ba2c85

Please sign in to comment.