diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 05dba0da7c..d7cc4b1733 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: - pyvista>=0.34.2 - scikit-image>=0.22.0 - torch>=2.3 - - torchmetrics>=0.10 + - torchmetrics>=1.1.1 - torchvision>=0.18 exclude: (build|data|dist|logo|logs|output)/ - repo: https://github.com/pre-commit/mirrors-prettier diff --git a/pyproject.toml b/pyproject.toml index 3a2c2319ee..a7c354d36c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,8 +72,8 @@ dependencies = [ "timm>=0.4.12", # torch 1.13+ required by torchvision "torch>=1.13", - # torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics - "torchmetrics>=0.10", + # torchmetrics 1.1.1+ required for average argument to MeanAveragePrecision + "torchmetrics>=1.1.1", # torchvision 0.14+ required for torchvision.models.swin_v2_b "torchvision>=0.14", ] diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index a6e91f70fe..cded7e1f9f 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -18,7 +18,7 @@ segmentation-models-pytorch==0.2.0 shapely==1.8.0 timm==0.4.12 torch==1.13.0 -torchmetrics==0.10.0 +torchmetrics==1.1.1 torchvision==0.14.0 # datasets diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 9eccda2f0f..72bb88a709 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -12,7 +12,14 @@ from matplotlib.figure import Figure from torch import Tensor from torchmetrics import MetricCollection -from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex +from torchmetrics.classification import ( + Accuracy, + FBetaScore, + JaccardIndex, + Precision, + Recall, +) +from torchmetrics.wrappers import ClasswiseWrapper from torchvision.models._api import WeightsEnum from ..datasets import RGBBandsMissingError, unbind_samples @@ -31,6 +38,7 @@ def __init__( weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, num_classes: int = 1000, + labels: list[str] | None = None, num_filters: int = 3, loss: str = 'ce', class_weights: Tensor | None = None, @@ -55,6 +63,7 @@ def __init__( are not supported yet. in_channels: Number of input channels to model. num_classes: Number of prediction classes (including the background). + labels: List of class labels. num_filters: Number of filters. Only applicable when model='fcn'. loss: Name of the loss function, currently supports 'ce', 'jaccard' or 'focal' loss. @@ -190,27 +199,70 @@ def configure_metrics(self) -> None: .. note:: * 'Micro' averaging suits overall performance evaluation but may not reflect minority class accuracy. - * 'Macro' averaging, not used here, gives equal weight to each class, useful + * 'Macro' averaging gives equal weight to each class, useful for balanced performance assessment across imbalanced classes. """ num_classes: int = self.hparams['num_classes'] ignore_index: int | None = self.hparams['ignore_index'] - metrics = MetricCollection( - [ - MulticlassAccuracy( - num_classes=num_classes, - ignore_index=ignore_index, - multidim_average='global', - average='micro', - ), - MulticlassJaccardIndex( - num_classes=num_classes, ignore_index=ignore_index, average='micro' - ), - ] - ) - self.train_metrics = metrics.clone(prefix='train_') - self.val_metrics = metrics.clone(prefix='val_') - self.test_metrics = metrics.clone(prefix='test_') + labels: list[str] | None = self.hparams['labels'] + + metric_classes = { + 'Accuracy': Accuracy, + 'F1Score': FBetaScore, + 'JaccardIndex': JaccardIndex, + 'Precision': Precision, + 'Recall': Recall, + } + + metrics_dict = {} + + # Loop through the types of averaging + for average in ['micro', 'macro']: + for metric_name, metric_class in metric_classes.items(): + name = ( + f'Overall{metric_name}' + if average == 'micro' + else f'Average{metric_name}' + ) + params = { + 'task': 'multiclass', + 'num_classes': num_classes, + 'average': average, + 'ignore_index': ignore_index, + } + if metric_name in ['Accuracy', 'F1Score', 'Precision', 'Recall']: + params['multidim_average'] = 'global' + if metric_name == 'F1Score': + params['beta'] = 1.0 + metrics_dict[name] = metric_class(**params) + + # Loop through the classwise metrics + for metric_name, metric_class in metric_classes.items(): + if metric_name != 'JaccardIndex': + metrics_dict[metric_name] = ClasswiseWrapper( + metric_class( + task='multiclass', + num_classes=num_classes, + average='none', + multidim_average='global', + ignore_index=ignore_index, + ), + labels=labels, + ) + else: + metrics_dict[metric_name] = ClasswiseWrapper( + metric_class( + task='multiclass', + num_classes=num_classes, + average='none', + ignore_index=ignore_index, + ), + labels=labels, + ) + + self.train_metrics = MetricCollection(metrics_dict, prefix='train_') + self.val_metrics = self.train_metrics.clone(prefix='val_') + self.test_metrics = self.train_metrics.clone(prefix='test_') def training_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 @@ -232,7 +284,10 @@ def training_step( loss: Tensor = self.criterion(y_hat, y) self.log('train_loss', loss, batch_size=batch_size) self.train_metrics(y_hat, y) - self.log_dict(self.train_metrics, batch_size=batch_size) + self.log_dict( + {f'{k}': v for k, v in self.train_metrics.compute().items()}, + batch_size=batch_size, + ) return loss def validation_step( @@ -252,7 +307,10 @@ def validation_step( loss = self.criterion(y_hat, y) self.log('val_loss', loss, batch_size=batch_size) self.val_metrics(y_hat, y) - self.log_dict(self.val_metrics, batch_size=batch_size) + self.log_dict( + {f'{k}': v for k, v in self.val_metrics.compute().items()}, + batch_size=batch_size, + ) if ( batch_idx < 10 @@ -296,7 +354,10 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None loss = self.criterion(y_hat, y) self.log('test_loss', loss, batch_size=batch_size) self.test_metrics(y_hat, y) - self.log_dict(self.test_metrics, batch_size=batch_size) + self.log_dict( + {f'{k}': v for k, v in self.test_metrics.compute().items()}, + batch_size=batch_size, + ) def predict_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0