diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5700a43a98b..40c5082370a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed for Pearson changes inputs ([#2765](https://github.com/Lightning-AI/torchmetrics/pull/2765))
 
+- Fixed bug in `PESQ` metric where `NoUtterancesError` prevented calculating the metric on a batch of data ([#2753](https://github.com/Lightning-AI/torchmetrics/pull/2753))
+
+
 ## [1.4.2] - 2024-09-12
 
 ### Added
 
diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst
index 09ceafb331f..a7678f3e2a5 100644
--- a/docs/source/pages/lightning.rst
+++ b/docs/source/pages/lightning.rst
@@ -108,7 +108,7 @@ also manually log the output of the metrics.
 
     class MyModule(LightningModule):
 
-        def __init__(self):
+        def __init__(self, num_classes):
             ...
             self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
             self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
@@ -157,6 +157,45 @@ Additionally, we highly recommend that the two ways of logging are not mixed as
         self.valid_acc.update(logits, y)
         self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
 
+In general, if you are logging multiple metrics, we highly recommend combining them into a single metric object
+using the :class:`~torchmetrics.MetricCollection` class and then replacing the ``self.log`` calls with ``self.log_dict``,
+assuming that all metrics receive the same input.
+
+.. testcode:: python
+
+    class MyModule(LightningModule):
+
+        def __init__(self, num_classes):
+            ...
+            self.train_metrics = torchmetrics.MetricCollection(
+                {
+                    "accuracy": torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes),
+                    "f1": torchmetrics.classification.F1Score(task="multiclass", num_classes=num_classes),
+                },
+                prefix="train_",
+            )
+            self.valid_metrics = self.train_metrics.clone(prefix="valid_")
+
+        def training_step(self, batch, batch_idx):
+            x, y = batch
+            preds = self(x)
+            ...
+            batch_value = self.train_metrics(preds, y)
+            self.log_dict(batch_value)
+
+        def on_train_epoch_end(self):
+            self.train_metrics.reset()
+
+        def validation_step(self, batch, batch_idx):
+            x, y = batch
+            logits = self(x)
+            ...
+            self.valid_metrics.update(logits, y)
+
+        def on_validation_epoch_end(self):
+            self.log_dict(self.valid_metrics.compute())
+            self.valid_metrics.reset()
+
 ***************
 Common Pitfalls
 ***************
@@ -164,6 +203,41 @@ Common Pitfalls
 
 The following contains a list of pitfalls to be aware of:
 
+* Logging a ``MetricCollection`` object directly using ``self.log_dict`` is only supported if all metrics in the
+  collection return a scalar tensor. If any of the metrics in the collection returns a non-scalar tensor,
+  the logging will fail. This can especially happen either when nesting multiple ``MetricCollection`` objects or when
+  using wrapper metrics such as :class:`~torchmetrics.wrappers.ClasswiseWrapper`,
+  :class:`~torchmetrics.wrappers.MinMaxMetric`, etc. inside a ``MetricCollection``, since all these wrappers return
+  dicts or lists of tensors. It is still possible to log such nested metrics manually, because the ``MetricCollection``
+  object will try to flatten everything into a single dict. Example:
+
+.. testcode:: python
+
+    class MyModule(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            self.train_metrics = MetricCollection(
+                {
+                    "macro_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="macro")),
+                    "weighted_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="weighted")),
+                },
+                prefix="train_",
+            )
+
+        def training_step(self, batch, batch_idx):
+            ...
+            # logging the MetricCollection object directly will fail
+            # self.log_dict(self.train_metrics)
+
+            # manually computing the result and then logging will work
+            batch_values = self.train_metrics(preds, target)
+            self.log_dict(batch_values, on_step=True, on_epoch=False)
+            ...
+
+        def on_train_epoch_end(self):
+            self.train_metrics.reset()
+
 * Modular metrics contain internal states that should belong to only one DataLoader. In case you are using multiple
   DataLoaders, it is recommended to initialize a separate modular metric instances for each DataLoader and use them
   separately. The same holds for using separate metrics for training, validation and testing.
 
diff --git a/src/torchmetrics/functional/audio/pesq.py b/src/torchmetrics/functional/audio/pesq.py
index ad1d9506f08..516014434bf 100644
--- a/src/torchmetrics/functional/audio/pesq.py
+++ b/src/torchmetrics/functional/audio/pesq.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any
+
 import numpy as np
 import torch
 from torch import Tensor
@@ -83,6 +85,11 @@ def perceptual_evaluation_speech_quality(
         )
     import pesq as pesq_backend
+    def _issubtype_number(x: Any) -> bool:
+        return np.issubdtype(type(x), np.number)
+
+    _filter_error_msg = np.vectorize(_issubtype_number)
+
     if fs not in (8000, 16000):
         raise ValueError(f"Expected argument `fs` to either be 8000 or 16000 but got {fs}")
     if mode not in ("wb", "nb"):
         raise ValueError(f"Expected argument `mode` to either be 'wb' or 'nb' but got {mode}")
@@ -103,8 +110,8 @@ def perceptual_evaluation_speech_quality(
     pesq_val_np = np.empty(shape=(preds_np.shape[0]))
     for b in range(preds_np.shape[0]):
         pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode)
-    pesq_val = torch.from_numpy(pesq_val_np)
-    pesq_val = pesq_val.reshape(preds.shape[:-1])
+    pesq_val = torch.from_numpy(pesq_val_np[_filter_error_msg(pesq_val_np)].astype(np.float32))
+    pesq_val = pesq_val.reshape(len(pesq_val))
 
     if keep_same_device:
         return pesq_val.to(preds.device)
diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py
index 513c351e2a0..302a4353f76 100644
--- a/tests/integrations/test_lightning.py
+++ b/tests/integrations/test_lightning.py
@@ -30,7 +30,7 @@
 from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision, MulticlassAccuracy
 from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
 from torchmetrics.utilities.prints import rank_zero_only
-from torchmetrics.wrappers import ClasswiseWrapper, MultitaskWrapper
+from torchmetrics.wrappers import ClasswiseWrapper, MinMaxMetric, MultitaskWrapper
 
 from integrations.lightning.boring_model import BoringModel
 
@@ -523,35 +523,26 @@ def __init__(self) -> None:
                 },
                 prefix="train_",
             )
-
-            self.val_metrics = MetricCollection(
-                {
-                    "macro_accuracy": MulticlassAccuracy(num_classes=5, average="macro"),
-                    "classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=5, average=None)),
-                },
-                prefix="val_",
-            )
+            self.val_metrics = self.train_metrics.clone(prefix="val_")
 
         def training_step(self, batch, batch_idx):
             loss = self(batch).sum()
-
             preds = torch.randint(0, 5, (100,), device=batch.device)
             target = torch.randint(0, 5, (100,), device=batch.device)
 
             self.train_metrics.update(preds, target)
             batch_values = self.train_metrics.compute()
             self.log_dict(batch_values, on_step=True, on_epoch=False)
-
             return {"loss": loss}
 
         def validation_step(self, batch, batch_idx):
             preds = torch.randint(0, 5, (100,), device=batch.device)
             target = torch.randint(0, 5, (100,), device=batch.device)
-
             self.val_metrics.update(preds, target)
 
         def on_validation_epoch_end(self):
             self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True)
+            self.val_metrics.reset()
 
     model = TestModel()
 
@@ -559,7 +550,7 @@ def on_validation_epoch_end(self):
         default_root_dir=tmpdir,
         limit_train_batches=2,
         limit_val_batches=2,
-        max_epochs=1,
+        max_epochs=2,
         log_every_n_steps=1,
     )
     trainer.fit(model)
@@ -572,3 +563,61 @@ def on_validation_epoch_end(self):
     for i in range(5):
         assert f"train_multiclassaccuracy_{i}" in logged
         assert f"val_multiclassaccuracy_{i}" in logged
+
+
+def test_collection_minmax_lightning_integration(tmpdir):
+    """Check the integration of MinMaxMetric, MetricCollection and LightningModule.
+
+    See issue: https://github.com/Lightning-AI/torchmetrics/issues/2763
+
+    """
+
+    class TestModel(BoringModel):
+        def __init__(self) -> None:
+            super().__init__()
+            self.train_metrics = MetricCollection(
+                {
+                    "macro_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="macro")),
+                    "weighted_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="weighted")),
+                },
+                prefix="train_",
+            )
+            self.val_metrics = self.train_metrics.clone(prefix="val_")
+
+        def training_step(self, batch, batch_idx):
+            loss = self(batch).sum()
+            preds = torch.randint(0, 5, (100,), device=batch.device)
+            target = torch.randint(0, 5, (100,), device=batch.device)
+
+            self.train_metrics.update(preds, target)
+            batch_values = self.train_metrics.compute()
+            self.log_dict(batch_values, on_step=True, on_epoch=False)
+            return {"loss": loss}
+
+        def validation_step(self, batch, batch_idx):
+            preds = torch.randint(0, 5, (100,), device=batch.device)
+            target = torch.randint(0, 5, (100,), device=batch.device)
+            self.val_metrics.update(preds, target)
+
+        def on_validation_epoch_end(self):
+            self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True)
+            self.val_metrics.reset()
+
+    model = TestModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2,
+        log_every_n_steps=1,
+    )
+    trainer.fit(model)
+
+    logged = trainer.logged_metrics
+
+    # check that all metrics are logged
+    for prefix in ["train_", "val_"]:
+        for metric in ["macro_accuracy", "weighted_accuracy"]:
+            for key in ["max", "min", "raw"]:
+                assert f"{prefix}{metric}_{key}" in logged
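
The `_filter_error_msg` helper added in `pesq.py` is a vectorized `np.issubdtype` check: it keeps only the numeric entries of the result array, so non-numeric entries (for example, an error message recorded when an utterance cannot be scored) are dropped before the scores are converted to a tensor. Below is a minimal, self-contained sketch of that filtering step; the example array and its string entry are hypothetical and not taken from the repository.

# Sketch of the numeric filtering used in perceptual_evaluation_speech_quality.
# The mixed object array is a hypothetical stand-in for a batch result in which
# one item could not be scored and carries an error message instead of a score.
from typing import Any

import numpy as np
import torch


def _issubtype_number(x: Any) -> bool:
    # True for ints/floats (Python or numpy), False for strings and other objects
    return np.issubdtype(type(x), np.number)


_filter_error_msg = np.vectorize(_issubtype_number)

pesq_val_np = np.array([1.08, "No utterances detected", 2.37], dtype=object)

mask = _filter_error_msg(pesq_val_np)  # array([ True, False,  True])
pesq_val = torch.from_numpy(pesq_val_np[mask].astype(np.float32))
print(pesq_val)  # tensor([1.0800, 2.3700])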