SemanticSegmentationTask: add class-wise metrics #2130

Open

wants to merge 32 commits into base: main

Commits (32)
23fa1fb
Add average metrics
robmarkcole Jun 19, 2024
b7d8305
Add average metrics
robmarkcole Jun 19, 2024
b1526fa
refactor: Rename metrics in SemanticSegmentationTask
robmarkcole Jun 19, 2024
341e272
Ruff format
robmarkcole Jun 19, 2024
024feda
Use ignore_index
robmarkcole Jun 20, 2024
04cac59
pass on_epoch
robmarkcole Jun 20, 2024
56f20fc
on_epoch to train too
robmarkcole Jun 20, 2024
3d2b309
Disable on_step for train metrics
robmarkcole Jun 20, 2024
9af1493
Merge branch 'main' into update-metrics
robmarkcole Jun 20, 2024
192c496
ruff format
robmarkcole Jun 20, 2024
73b710f
Merge branch 'main' into update-metrics
robmarkcole Jun 21, 2024
8ce8c30
Merge branch 'main' into update-metrics
robmarkcole Jun 23, 2024
e4ed9fd
Merge branch 'main' into update-metrics
robmarkcole Jul 2, 2024
d9c2688
Merge branch 'main' into update-metrics
robmarkcole Jul 8, 2024
400fae3
Merge branch 'main' into update-metrics
robmarkcole Jul 21, 2024
f4c793e
Merge branch 'main' into update-metrics
robmarkcole Aug 1, 2024
3b629ea
Merge branch 'main' into update-metrics
robmarkcole Aug 5, 2024
e6abadd
Merge branch 'main' into update-metrics
robmarkcole Aug 6, 2024
da887fe
Bump min torchmetrics
robmarkcole Aug 6, 2024
5138ccb
Merge branch 'update-metrics' of https://github.com/robmarkcole/torch…
robmarkcole Aug 6, 2024
479c7e3
Raise torchmetrics min
robmarkcole Aug 6, 2024
50b7d29
remo on_epoch etc
robmarkcole Aug 7, 2024
1cd436f
Remove on_epoch
robmarkcole Aug 7, 2024
9e985e2
try torchmetrics==1.1.0
robmarkcole Aug 7, 2024
c773322
try torchmetrics==1.1.1
robmarkcole Aug 7, 2024
9d8c8e4
Merge branch 'main' into update-metrics
robmarkcole Aug 7, 2024
e2640f5
Use loop to generate metrics
robmarkcole Aug 8, 2024
19187a9
Update
robmarkcole Aug 8, 2024
a3f7ffe
Fix jaccard
robmarkcole Aug 8, 2024
9a66442
fix dependencies delta
robmarkcole Aug 8, 2024
8381cb7
fix pyproject
robmarkcole Aug 8, 2024
b5050ad
Merge branch 'main' into update-metrics
robmarkcole Sep 5, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -31,7 +31,7 @@ repos:
- pyvista>=0.34.2
- scikit-image>=0.22.0
- torch>=2.3
- torchmetrics>=0.10
- torchmetrics>=1.1.1
- torchvision>=0.18
exclude: (build|data|dist|logo|logs|output)/
- repo: https://github.com/pre-commit/mirrors-prettier
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -72,8 +72,8 @@ dependencies = [
"timm>=0.4.12",
# torch 1.13+ required by torchvision
"torch>=1.13",
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
"torchmetrics>=0.10",
# torchmetrics 1.1.1+ required for average argument to MeanAveragePrecision
"torchmetrics>=1.1.1",
# torchvision 0.14+ required for torchvision.models.swin_v2_b
"torchvision>=0.14",
]
2 changes: 1 addition & 1 deletion requirements/min-reqs.old
@@ -18,7 +18,7 @@ segmentation-models-pytorch==0.2.0
shapely==1.8.0
timm==0.4.12
torch==1.13.0
torchmetrics==0.10.0
torchmetrics==1.1.1
torchvision==0.14.0

# datasets
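All three dependency files raise the torchmetrics floor in lockstep. A quick environment check (an illustrative sketch, not part of the PR; assumes the `packaging` helper is available):

```python
# Illustrative only: confirm the installed torchmetrics satisfies the
# new minimum pinned above.
from importlib.metadata import version

from packaging.version import Version

assert Version(version('torchmetrics')) >= Version('1.1.1')
```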
103 changes: 82 additions & 21 deletions torchgeo/trainers/segmentation.py
@@ -12,7 +12,14 @@
from matplotlib.figure import Figure
from torch import Tensor
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchmetrics.classification import (
Accuracy,
FBetaScore,
JaccardIndex,
Precision,
Recall,
)
from torchmetrics.wrappers import ClasswiseWrapper
from torchvision.models._api import WeightsEnum

from ..datasets import RGBBandsMissingError, unbind_samples
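The newly imported ClasswiseWrapper is the heart of this change. A minimal standalone sketch of its behaviour, assuming torchmetrics>=1.1.1 and hypothetical class labels: it runs the wrapped metric with average='none' and flattens the per-class tensor into a dict keyed by label.

```python
import torch
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.wrappers import ClasswiseWrapper

# Wrap a per-class metric (average='none'); the wrapper returns one
# entry per label instead of a single averaged tensor.
metric = ClasswiseWrapper(
    MulticlassAccuracy(num_classes=3, average='none', multidim_average='global'),
    labels=['background', 'water', 'forest'],
)
preds = torch.randn(2, 3, 8, 8)          # (batch, classes, H, W) logits
target = torch.randint(0, 3, (2, 8, 8))  # (batch, H, W) class indices
print(metric(preds, target))
# e.g. {'multiclassaccuracy_background': tensor(0.3438), ...}
```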
@@ -31,6 +38,7 @@ def __init__(
weights: WeightsEnum | str | bool | None = None,
in_channels: int = 3,
num_classes: int = 1000,
labels: list[str] | None = None,
num_filters: int = 3,
loss: str = 'ce',
class_weights: Tensor | None = None,
@@ -55,6 +63,7 @@
are not supported yet.
in_channels: Number of input channels to model.
num_classes: Number of prediction classes (including the background).
labels: List of class labels.
num_filters: Number of filters. Only applicable when model='fcn'.
loss: Name of the loss function, currently supports
'ce', 'jaccard' or 'focal' loss.
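Hypothetical usage of the new labels argument (the model/backbone choices below are just examples):

```python
from torchgeo.trainers import SemanticSegmentationTask

# The labels list should line up with num_classes so that per-class
# metrics are logged under readable names.
task = SemanticSegmentationTask(
    model='unet',
    backbone='resnet50',
    in_channels=3,
    num_classes=3,
    labels=['background', 'water', 'forest'],
)
```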
@@ -190,27 +199,70 @@ def configure_metrics(self) -> None:
.. note::
* 'Micro' averaging suits overall performance evaluation but may not reflect
minority class accuracy.
* 'Macro' averaging, not used here, gives equal weight to each class, useful
* 'Macro' averaging gives equal weight to each class, useful
for balanced performance assessment across imbalanced classes.
"""
num_classes: int = self.hparams['num_classes']
ignore_index: int | None = self.hparams['ignore_index']
metrics = MetricCollection(
[
MulticlassAccuracy(
num_classes=num_classes,
ignore_index=ignore_index,
multidim_average='global',
average='micro',
),
MulticlassJaccardIndex(
num_classes=num_classes, ignore_index=ignore_index, average='micro'
),
]
)
self.train_metrics = metrics.clone(prefix='train_')
self.val_metrics = metrics.clone(prefix='val_')
self.test_metrics = metrics.clone(prefix='test_')
labels: list[str] | None = self.hparams['labels']

metric_classes = {
'Accuracy': Accuracy,
'F1Score': FBetaScore,
'JaccardIndex': JaccardIndex,
'Precision': Precision,
'Recall': Recall,
}

metrics_dict = {}

# Loop through the types of averaging
for average in ['micro', 'macro']:
for metric_name, metric_class in metric_classes.items():
name = (
f'Overall{metric_name}'
if average == 'micro'
else f'Average{metric_name}'
)
params = {
'task': 'multiclass',
'num_classes': num_classes,
'average': average,
'ignore_index': ignore_index,
}
if metric_name in ['Accuracy', 'F1Score', 'Precision', 'Recall']:
params['multidim_average'] = 'global'
if metric_name == 'F1Score':
params['beta'] = 1.0
metrics_dict[name] = metric_class(**params)

# Loop through the classwise metrics
for metric_name, metric_class in metric_classes.items():
if metric_name != 'JaccardIndex':
metrics_dict[metric_name] = ClasswiseWrapper(
metric_class(
task='multiclass',
num_classes=num_classes,
average='none',
multidim_average='global',
ignore_index=ignore_index,
),
labels=labels,
)
else:
metrics_dict[metric_name] = ClasswiseWrapper(
metric_class(
task='multiclass',
num_classes=num_classes,
average='none',
ignore_index=ignore_index,
),
labels=labels,
)

self.train_metrics = MetricCollection(metrics_dict, prefix='train_')
self.val_metrics = self.train_metrics.clone(prefix='val_')
self.test_metrics = self.train_metrics.clone(prefix='test_')

def training_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
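To make the naming scheme concrete, here is a reduced sketch of the collection the loops in configure_metrics build (Accuracy only, hypothetical labels): micro-averaged metrics are logged as Overall*, macro-averaged ones as Average*, and each ClasswiseWrapper entry expands into one key per label.

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy
from torchmetrics.wrappers import ClasswiseWrapper

metrics = MetricCollection(
    {
        'OverallAccuracy': Accuracy(
            task='multiclass', num_classes=3, average='micro',
            multidim_average='global',
        ),
        'AverageAccuracy': Accuracy(
            task='multiclass', num_classes=3, average='macro',
            multidim_average='global',
        ),
        # The collection key is replaced by the wrapper's per-label keys.
        'Accuracy': ClasswiseWrapper(
            Accuracy(
                task='multiclass', num_classes=3, average='none',
                multidim_average='global',
            ),
            labels=['background', 'water', 'forest'],
        ),
    },
    prefix='val_',
)
preds = torch.randn(2, 3, 4, 4)
target = torch.randint(0, 3, (2, 4, 4))
print(sorted(metrics(preds, target)))
# e.g. ['val_AverageAccuracy', 'val_OverallAccuracy',
#       'val_multiclassaccuracy_background', ...]
```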
@@ -232,7 +284,10 @@ def training_step(
loss: Tensor = self.criterion(y_hat, y)
self.log('train_loss', loss, batch_size=batch_size)
self.train_metrics(y_hat, y)
self.log_dict(self.train_metrics, batch_size=batch_size)
self.log_dict(
{f'{k}': v for k, v in self.train_metrics.compute().items()},
batch_size=batch_size,
)
return loss

def validation_step(
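The logging change above (repeated in validation_step and test_step below) passes the flat dict from compute() to self.log_dict rather than the live MetricCollection, presumably because ClasswiseWrapper members yield dict-valued results. A runnable sketch of that pattern outside Lightning, with a hypothetical 3-class setup:

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import JaccardIndex
from torchmetrics.wrappers import ClasswiseWrapper

train_metrics = MetricCollection(
    {
        'OverallJaccardIndex': JaccardIndex(
            task='multiclass', num_classes=3, average='micro'
        ),
        'JaccardIndex': ClasswiseWrapper(
            JaccardIndex(task='multiclass', num_classes=3, average='none'),
            labels=['background', 'water', 'forest'],
        ),
    },
    prefix='train_',
)
y_hat = torch.randn(2, 3, 4, 4)
y = torch.randint(0, 3, (2, 4, 4))
train_metrics(y_hat, y)            # update state, as in the step above
payload = train_metrics.compute()  # flat {name: 0-dim tensor} dict
print(payload)
# Inside the LightningModule: self.log_dict(payload, batch_size=batch_size)
```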
@@ -252,7 +307,10 @@
loss = self.criterion(y_hat, y)
self.log('val_loss', loss, batch_size=batch_size)
self.val_metrics(y_hat, y)
self.log_dict(self.val_metrics, batch_size=batch_size)
self.log_dict(
{f'{k}': v for k, v in self.val_metrics.compute().items()},
batch_size=batch_size,
)

if (
batch_idx < 10
@@ -296,7 +354,10 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
loss = self.criterion(y_hat, y)
self.log('test_loss', loss, batch_size=batch_size)
self.test_metrics(y_hat, y)
self.log_dict(self.test_metrics, batch_size=batch_size)
self.log_dict(
{f'{k}': v for k, v in self.test_metrics.compute().items()},
batch_size=batch_size,
)

def predict_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0