Docs: fix trainer metric definitions (#1924)
* Docs: fix trainer metric definitions

* Link to torchmetrics docs

* Teach Sphinx where docs live
adamjstewart authored and isaaccorley committed Mar 3, 2024
1 parent dd38fdd commit b9653be
Showing 5 changed files with 34 additions and 35 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
@@ -119,6 +119,7 @@
     "sklearn": ("https://scikit-learn.org/stable/", None),
     "timm": ("https://huggingface.co/docs/timm/main/en/", None),
     "torch": ("https://pytorch.org/docs/stable", None),
+    "torchmetrics": ("https://lightning.ai/docs/torchmetrics/stable/", None),
     "torchvision": ("https://pytorch.org/vision/stable", None),
 }

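This one-line mapping is what lets Sphinx resolve the `:class:` cross-references used in the rewritten docstrings below. For example, a role such as the following (class path taken from the diffs below) now links out to the torchmetrics site instead of failing as an unresolved reference:

```rst
:class:`~torchmetrics.classification.MulticlassAccuracy`
```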
30 changes: 16 additions & 14 deletions torchgeo/trainers/classification.py
@@ -97,14 +97,15 @@ def configure_losses(self) -> None:
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.

-        * Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels.
-          Uses 'micro' averaging. Higher values are better.
-        * Multiclass Average Accuracy (AA): Ratio of correctly classified classes.
-          Uses 'macro' averaging. Higher values are better.
-        * Multiclass Jaccard Index (IoU): Per-class overlap between predicted and
-          actual classes. Uses 'macro' averaging. Higher valuers are better.
-        * Multiclass F1 Score: The harmonic mean of precision and recall.
-          Uses 'micro' averaging. Higher values are better.
+        * :class:`~torchmetrics.classification.MulticlassAccuracy`: The number of
+          true positives divided by the dataset size. Both overall accuracy (OA)
+          using 'micro' averaging and average accuracy (AA) using 'macro' averaging
+          are reported. Higher values are better.
+        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
+          over union (IoU). Uses 'macro' averaging. Higher values are better.
+        * :class:`~torchmetrics.classification.MulticlassFBetaScore`: F1 score.
+          The harmonic mean of precision and recall. Uses 'micro' averaging.
+          Higher values are better.

         .. note::
             * 'Micro' averaging suits overall performance evaluation but may not reflect
@@ -266,12 +267,13 @@ class MultiLabelClassificationTask(ClassificationTask):
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.

-        * Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels.
-          Uses 'micro' averaging. Higher values are better.
-        * Multiclass Average Accuracy (AA): Ratio of correctly classified classes.
-          Uses 'macro' averaging. Higher values are better.
-        * Multiclass F1 Score: The harmonic mean of precision and recall.
-          Uses 'micro' averaging. Higher values are better.
+        * :class:`~torchmetrics.classification.MultilabelAccuracy`: The number of
+          true positives divided by the dataset size. Both overall accuracy (OA)
+          using 'micro' averaging and average accuracy (AA) using 'macro' averaging
+          are reported. Higher values are better.
+        * :class:`~torchmetrics.classification.MultilabelFBetaScore`: F1 score.
+          The harmonic mean of precision and recall. Uses 'micro' averaging.
+          Higher values are better.

         .. note::
             * 'Micro' averaging suits overall performance evaluation but may not
11 changes: 6 additions & 5 deletions torchgeo/trainers/detection.py
@@ -205,18 +205,19 @@ def configure_models(self) -> None:
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.

-        * Mean Average Precision (mAP): Computes the Mean-Average-Precision (mAP) and
-          Mean-Average-Recall (mAR) for object detection. Prediction is based on the
-          intersection over union (IoU) between the predicted bounding boxes and the
-          ground truth bounding boxes. Uses 'macro' averaging. Higher values are better.
+        * :class:`~torchmetrics.detection.mean_ap.MeanAveragePrecision`: Mean average
+          precision (mAP) and mean average recall (mAR). Precision is the number of
+          true positives divided by the number of true positives + false positives.
+          Recall is the number of true positives divided by the number of true positives
+          + false negatives. Uses 'macro' averaging. Higher values are better.

         .. note::
             * 'Micro' averaging suits overall performance evaluation but may not
               reflect minority class accuracy.
             * 'Macro' averaging gives equal weight to each class, and is useful for
               balanced performance assessment across imbalanced classes.
         """
-        metrics = MetricCollection([MeanAveragePrecision()])
+        metrics = MetricCollection([MeanAveragePrecision(average="macro")])
         self.val_metrics = metrics.clone(prefix="val_")
         self.test_metrics = metrics.clone(prefix="test_")

18 changes: 6 additions & 12 deletions torchgeo/trainers/regression.py
@@ -98,18 +98,12 @@ def configure_losses(self) -> None:
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.

-        * Root Mean Squared Error (RMSE): The square root of the average of the squared
-          differences between the predicted and actual values. Lower values are better.
-        * Mean Squared Error (MSE): The average of the squared differences between the
-          predicted and actual values. Lower values are better.
-        * Mean Absolute Error (MAE): The average of the absolute differences between the
-          predicted and actual values. Lower values are better.
-
-        .. note::
-            * 'Micro' averaging suits overall performance evaluation but may not reflect
-              minority class accuracy.
-            * 'Macro' averaging gives equal weight to each class, and is useful for
-              balanced performance assessment across imbalanced classes.
+        * :class:`~torchmetrics.MeanSquaredError`: The average of the squared
+          differences between the predicted and actual values (MSE) and its
+          square root (RMSE). Lower values are better.
+        * :class:`~torchmetrics.MeanAbsoluteError`: The average of the absolute
+          differences between the predicted and actual values (MAE).
+          Lower values are better.
         """
         metrics = MetricCollection(
             {
9 changes: 5 additions & 4 deletions torchgeo/trainers/segmentation.py
@@ -126,10 +126,11 @@ def configure_losses(self) -> None:
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.

-        * Multiclass Pixel Accuracy: Ratio of correctly classified pixels.
-          Uses 'micro' averaging. Higher values are better.
-        * Multiclass Jaccard Index (IoU): Per-pixel overlap between predicted and
-          actual segments. Uses 'macro' averaging. Higher values are better.
+        * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy
+          (OA) using 'micro' averaging. The number of true positives divided by the
+          dataset size. Higher values are better.
+        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
+          over union (IoU). Uses 'micro' averaging. Higher values are better.

         .. note::
             * 'Micro' averaging suits overall performance evaluation but may not reflect
