
SemanticSegmentationTask: add class-wise metrics #2130

Open
wants to merge 32 commits into base: main
Conversation

robmarkcole
Contributor

@robmarkcole robmarkcole commented Jun 19, 2024

Addresses #2121 for segmentation. Mostly copied from @isaaccorley as here; note that he additionally passes on_epoch=True, which is NOT adopted here.

Output metrics for the ChaBud binary task with labels=['background', 'burned_area'].
This dataset nicely illustrates why class labels are required: burned_area is the minority class and is not learnt.

[{'test_loss': 450.0233459472656,
  'test_multiclassaccuracy_background': 0.9817732572555542,
  'test_multiclassaccuracy_burned_area': 0.006427088752388954,
  'test_AverageAccuracy': 0.4941001236438751,
  'test_AverageF1Score': 0.4793838560581207,
  'test_AverageJaccardIndex': 0.4542679488658905,
  'test_multiclassfbetascore_background': 0.9489027857780457,
  'test_multiclassfbetascore_burned_area': 0.009864915162324905,
  'test_multiclassjaccardindex_background': 0.9035778641700745,
  'test_multiclassjaccardindex_burned_area': 0.004958001431077719,
  'test_OverallAccuracy': 0.9036323428153992,
  'test_OverallF1Score': 0.9036323428153992,
  'test_OverallJaccardIndex': 0.8265058994293213,
  'test_multiclassprecision_background': 0.9189077615737915,
  'test_multiclassprecision_burned_area': 0.0312582366168499,
  'test_multiclassrecall_background': 0.9817732572555542,
  'test_multiclassrecall_burned_area': 0.006427088752388954}]

@github-actions github-actions bot added the trainers PyTorch Lightning trainers label Jun 19, 2024
@DimitrisMantas
Contributor

DimitrisMantas commented Jun 20, 2024

Given that most metrics of interest are broken (e.g., all of them when average="macro" and ignore_index is specified (Lightning-AI/torchmetrics#2443), and JaccardIndex, which outputs NaN when average="macro" if you try to take absent and ignored classes into account with zero_division (Lightning-AI/torchmetrics#2535)), should we make an effort to see if and how we could add our own?

I'm saying this because these are only the issues I've found so far; I've also noticed other suspicious things, like the fact that my class-wise recall values are not the same as those in the confusion matrix when it is normalized with respect to the ground truth (I haven't checked whether this is also the case for precision, i.e., when the matrix is normalized column-wise). I'm also pretty confident that if all of this is wrong, then micro averaging is probably wrong too.

It should be pretty easy to compute all these metrics straight from the confusion matrix (assuming it, at least, is correct), and I've actually tried to reimplement them this way, but it hasn't really been a priority because I've found that all these wrong (?) values are basically a lower bound on the actual ones. If you look at the official implementations, computing from the confusion matrix is actually what they do, so my guess is that they have a bug in their logic later on. Indeed, all these metrics inherit from StatScores, which is basically the confusion matrix.
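The confusion-matrix route can be sketched in a few lines of plain Python (a hypothetical helper, not TorchMetrics code): per-class recall, precision, and IoU all fall out of the matrix's rows, columns, and diagonal.

```python
def classwise_from_confmat(cm):
    """Per-class recall, precision, and IoU from a confusion matrix.

    cm[i][j] = number of pixels with ground truth i predicted as j,
    i.e., rows are ground truth and columns are predictions.
    """
    n = len(cm)
    recall, precision, iou = [], [], []
    for c in range(n):
        tp = cm[c][c]
        fn = sum(cm[c]) - tp                        # rest of row c
        fp = sum(cm[r][c] for r in range(n)) - tp   # rest of column c
        recall.append(tp / (tp + fn) if tp + fn else 0.0)
        precision.append(tp / (tp + fp) if tp + fp else 0.0)
        iou.append(tp / (tp + fp + fn) if tp + fp + fn else 0.0)
    return recall, precision, iou

# Toy 2-class example: 8 background pixels, 2 burned_area pixels.
cm = [[7, 1],
      [1, 1]]
recall, precision, iou = classwise_from_confmat(cm)
```

Macro averages are then just the mean of each list, and micro averages pool the TP/FP/FN counts before dividing.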

I'm actually pretty dumbfounded that these issues are not a top priority for the TorchMetrics team, and that they focus on adding to their docs instead, but to each their own…

@robmarkcole
Contributor Author

@DimitrisMantas good call on my ignoring ignore_index! In fairness, they do address issues, but they have a long backlog; when I made some noise they addressed Lightning-AI/torchmetrics#2198.
My opinion is that it's better to work with torchmetrics to address the issues rather than implement from scratch here. I see your comment at Lightning-AI/torchmetrics#2535 (comment), so perhaps a pragmatic approach is to not add new metrics that we have concerns about, and also to create specific issues which track these concerns.

@DimitrisMantas
Contributor

Sure, that makes sense; please excuse the rant haha.

@robmarkcole
Contributor Author

robmarkcole commented Jun 20, 2024

Applied on_epoch=True to all steps for consistency; this results in both per-epoch and per-step values being reported for train only. Perhaps this is why @isaaccorley did not apply it to train?

train_loss_epoch | 0.028535427525639534
train_loss_step | 0.00008003244874998927
train_AverageAccuracy_epoch | 0.9101453423500061
train_AverageAccuracy_step | 0.9124529361724854

Note that Val is unaffected:

val_AverageAccuracy | 0.8227439522743225

For a task with 2 classes, a grand total of 52 metrics are reported between train & val.

@isaaccorley
Collaborator

isaaccorley commented Jun 20, 2024

I just set it to be explicit, but I think that PyTorch Lightning or torchmetrics automatically sets on_epoch to False for training and True for everything else.

@DimitrisMantas
Contributor

You need to set both on_step and on_epoch explicitly to get logs only per step or only per epoch.
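For reference, the interaction of the two flags in Lightning's self.log can be summarized as follows (a comment-only sketch of the documented behaviour, not the PR's exact calls):

```python
# Inside a LightningModule step method (sketch):
#
#   self.log("train_loss", loss, on_step=True,  on_epoch=False)  # per-step only
#   self.log("train_loss", loss, on_step=False, on_epoch=True)   # per-epoch only
#   self.log("train_loss", loss, on_step=True,  on_epoch=True)   # both: produces
#       # "train_loss_step" and "train_loss_epoch", as in the table above
#
# Defaults when the flags are omitted: on_step=True, on_epoch=False in
# training_step; on_step=False, on_epoch=True in validation/test steps.
```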

@robmarkcole
Contributor Author

@DimitrisMantas now just logging the train loss on_step, so a more manageable 36 metrics now.

@robmarkcole
Contributor Author

Not sure about this failing test: ValueError: Problem with given class_path 'torchgeo.trainers.SemanticSegmentationTask'

@isaaccorley
Collaborator

Must be an issue with one of the minimum versions of the packages, since it's passing for the other tests.

@adamjstewart
Collaborator

We can definitely increase the min version of torchmetrics if we need to.

@robmarkcole
Contributor Author

@adamjstewart I'm inclined to close this PR, as I don't feel confident I understand the behaviour of torchmetrics in this implementation. Elsewhere I am using the on_stage_epoch_end hooks and feel confident I do understand the behaviour with that approach. Overall, I think this should be a change we make from a place of understanding, and in smaller steps than this PR takes.

@robmarkcole
Contributor Author

robmarkcole commented Aug 7, 2024

torchmetrics==1.1.0 test errors here:

MeanAveragePrecision(), kwargs = {'average': 'macro'}
...
ValueError: Unexpected keyword arguments: `average`

It seems this was added in 1.1.1.

@robmarkcole
Contributor Author

After discussion with torchmetrics devs, created Lightning-AI/torchmetrics#2683

@adamjstewart
Collaborator

That's such a complicated minimal reproducible example lol.

@adamjstewart
Collaborator

I tried making a self-contained minimal reproducible example but couldn't get one working and gave up.

@adamjstewart adamjstewart removed this from the 0.6.0 milestone Aug 27, 2024
@DimitrisMantas
Contributor

It just hit me that we should be a bit careful about which metrics we add, to avoid unnecessary computation; class-wise accuracy and recall are the same thing, and so are micro-averaged accuracy, precision, and recall.
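That overlap is easy to verify by hand with toy counts (plain Python, not TorchMetrics code): per-class "accuracy" computed from a confusion-matrix row is exactly per-class recall, and the micro averages of accuracy, precision, and recall all reduce to correct / total.

```python
# Toy confusion matrix: rows = ground truth, columns = predictions.
cm = [[50, 10],
      [5, 35]]
n = len(cm)
tp = [cm[c][c] for c in range(n)]
total = sum(sum(row) for row in cm)

# Per-class "accuracy" from row c of the matrix is exactly per-class recall.
classwise_recall = [tp[c] / sum(cm[c]) for c in range(n)]

# Micro averaging pools true positives over all classes, so accuracy,
# precision, and recall all collapse to correct / total.
micro_recall = sum(tp) / sum(sum(row) for row in cm)            # all GT pixels
micro_precision = sum(tp) / sum(sum(col) for col in zip(*cm))   # all predictions
micro_accuracy = sum(tp) / total
```

So logging all three micro-averaged metrics (or both class-wise accuracy and recall) computes the same numbers more than once.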

@adamjstewart
Collaborator

Any sense of how much these metrics actually add to processing time? If it isn't noticeable by a human, I don't particularly care about the overhead.

@DimitrisMantas
Contributor

Haven't measured it, but I doubt it's much.

@adamjstewart adamjstewart added this to the 0.7.0 milestone Aug 27, 2024
@robmarkcole
Contributor Author

I believe Lightning offers tools for profiling
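Lightning's built-in profilers would be an easy way to check; a minimal sketch (config fragment, assuming the lightning>=2.0 import path):

```python
from lightning.pytorch import Trainer

# "simple" reports cumulative wall-clock time per hook (which would include
# metric updates in the *_step hooks); "advanced" wraps cProfile instead.
trainer = Trainer(profiler="simple", max_epochs=1)
```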

@robmarkcole
Contributor Author

@adamjstewart @DimitrisMantas per this comment, we should be using the *_epoch_end hooks: Lightning-AI/torchmetrics#2683 (comment)

@DimitrisMantas
Contributor

DimitrisMantas commented Sep 5, 2024

I see the issue, but I must be missing something, because my own code uses the standard logging tools and metric collections, and they work just fine.

Although by "work", I mean I don't get an error. Other than that, I found out a couple of days ago that the diagonal of my confusion matrix doesn't match the class accuracies (which it should), so I'm obviously not using the API correctly...

Edit: I have at least one mistake where I do self.log_dict(metrics(input, target)). The docs say this is wrong.

Edit 2: Aaaaand I finally got your error...

@DimitrisMantas
Contributor

Ok, so basically what the torchmetrics guys are saying is that automatic logging is not supported for metric collections?

@DimitrisMantas
Contributor

@robmarkcole I can confirm the recommended approach yields consistent results.

Labels
dependencies (Packaging and dependencies) · trainers (PyTorch Lightning trainers)
5 participants