
log_dict to support ClasswiseWrapper #2683

Closed
robmarkcole opened this issue Aug 8, 2024 · 4 comments · Fixed by #2720
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@robmarkcole

🚀 Feature

I want to be able to call:

self.my_metrics(y_hat, y)
self.log_dict(self.my_metrics)

where my_metrics is a MetricCollection that includes ClasswiseWrapper metrics. Currently this will fail with an error like:

ValueError: The `.compute()` return of the metric logged as 'val_Accuracy' must be a tensor. Found {'multiclassaccuracy_background': tensor(0.9754, device='cuda:0'), ...}

An example from this PR of how to generate metrics that will currently fail:

# Excerpt from the LightningModule's __init__. It assumes that num_classes,
# ignore_index and labels are defined, and that Accuracy, FBetaScore,
# JaccardIndex, Precision, Recall, ClasswiseWrapper and MetricCollection
# have been imported from torchmetrics.
metric_classes = {
    'Accuracy': Accuracy,
    'F1Score': FBetaScore,
    'JaccardIndex': JaccardIndex,
    'Precision': Precision,
    'Recall': Recall,
}

metrics_dict = {}

# Loop through the types of averaging
for average in ['micro', 'macro']:
    for metric_name, metric_class in metric_classes.items():
        name = (
            f'Overall{metric_name}'
            if average == 'micro'
            else f'Average{metric_name}'
        )
        params = {
            'task': 'multiclass',
            'num_classes': num_classes,
            'average': average,
            'ignore_index': ignore_index,
        }
        if metric_name in ['Accuracy', 'F1Score', 'Precision', 'Recall']:
            params['multidim_average'] = 'global'
        if metric_name == 'F1Score':
            params['beta'] = 1.0
        metrics_dict[name] = metric_class(**params)

# Loop through the classwise metrics (JaccardIndex does not take multidim_average)
for metric_name, metric_class in metric_classes.items():
    if metric_name != 'JaccardIndex':
        metrics_dict[metric_name] = ClasswiseWrapper(
            metric_class(
                task='multiclass',
                num_classes=num_classes,
                average='none',
                multidim_average='global',
                ignore_index=ignore_index,
            ),
            labels=labels,
        )
    else:
        metrics_dict[metric_name] = ClasswiseWrapper(
            metric_class(
                task='multiclass',
                num_classes=num_classes,
                average='none',
                ignore_index=ignore_index,
            ),
            labels=labels,
        )

self.train_metrics = MetricCollection(metrics_dict, prefix='train_')
@robmarkcole added the enhancement label Aug 8, 2024
@Borda added the help wanted label Aug 21, 2024
@Borda changed the title to log_dict to support ClasswiseWrapper Aug 21, 2024
@SkafteNicki
Member

Hi @robmarkcole, thanks for raising this issue, I finally had some time to tackle it.

I have some bad and some good news. The bad news is that I do not think it will be possible to directly integrate log_dict with ClasswiseWrapper and MetricCollection in the way you are envisioning. The good news is that there is still an official way to get it working, but it requires some code changes on your side.

Why integrating ClasswiseWrapper with log_dict is not possible

I will try to explain my reasoning here to the best of my ability. First off, the self.log_dict method is nothing more than a wrapper around the log method:

def log_dict(self, dictionary):
    for k, v in dictionary.items():
        self.log(k, v)

The exact code is here: https://github.com/Lightning-AI/pytorch-lightning/blob/f3f10d460338ca8b2901d5cd43456992131767ec/src/lightning/pytorch/core/module.py#L547-L625
dictionary can either be a standard Python dictionary or a MetricCollection. Secondly, self.log only supports logging scalar tensors, Python floats, and torchmetrics Metrics that return a scalar tensor. It is really that last part that is the blocker here. As an example, self.log cannot be used to log the output of ConfusionMatrix because it returns a [num_classes, num_classes] tensor. So when self.log_dict is called on a MetricCollection, the individual metrics in the collection need to be compatible with self.log.
Since ClasswiseWrapper converts the output of a metric into a dictionary, it is not compatible with self.log. One solution would be to make self.log compatible with dictionaries, but that would defeat the purpose of self.log_dict. In general, the self.log/self.log_dict API is considered stable, so we cannot really touch it.
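
To make the constraint concrete, here is a minimal standalone sketch (plain torchmetrics, no Lightning; variable names are just for illustration) showing what compute returns for these metrics. Neither result is the single scalar tensor that self.log accepts:

import torch
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix
from torchmetrics.wrappers import ClasswiseWrapper

preds = torch.randint(0, 3, (100,))
target = torch.randint(0, 3, (100,))

# ClasswiseWrapper turns the per-class result into a dict of scalar tensors
classwise_acc = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
classwise_acc.update(preds, target)
print(classwise_acc.compute())
# {'multiclassaccuracy_0': tensor(...), 'multiclassaccuracy_1': tensor(...), 'multiclassaccuracy_2': tensor(...)}

# ConfusionMatrix returns a [num_classes, num_classes] tensor, also not a scalar
confmat = MulticlassConfusionMatrix(num_classes=3)
confmat.update(preds, target)
print(confmat.compute().shape)  # torch.Size([3, 3])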

How to work around this then

If you take a look at the documentation for logging torchmetrics in PyTorch Lightning, there are two official ways:
https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html#logging-torchmetrics

  • Either call self.log/self.log_dict directly on the metric objects, and Lightning internally takes care of calling the compute method at the right time. We can call this automatic logging (see the sketch just below).
  • Or call self.log/self.log_dict on the result of calling the compute method yourself, just like any other value you may want to log. We can call this manual logging.
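
For reference, here is a minimal sketch of the automatic path for a scalar-returning metric, modeled on the linked documentation (the class name AutoLoggingModel is made up, and BoringModel is used purely as a stand-in module, as in the integration test further down). It works because MulticlassAccuracy with average="macro" computes to a scalar tensor:

import torch
from lightning.pytorch.demos.boring_classes import BoringModel
from torchmetrics.classification import MulticlassAccuracy


class AutoLoggingModel(BoringModel):  # hypothetical name, for illustration only
    def __init__(self) -> None:
        super().__init__()
        self.val_acc = MulticlassAccuracy(num_classes=5, average="macro")

    def validation_step(self, batch, batch_idx):
        preds = torch.randint(0, 5, (100,), device=batch.device)
        target = torch.randint(0, 5, (100,), device=batch.device)
        self.val_acc.update(preds, target)
        # Automatic logging: the metric *object* is passed, so Lightning calls
        # compute() at epoch end and resets the metric afterwards. This only
        # works because the metric computes to a scalar tensor.
        self.log("val_macro_accuracy", self.val_acc, on_step=False, on_epoch=True)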

Currently, this issue is about getting the automatic logging to work with ClasswiseWrapper and self.log_dict. The good news is that it already works with manual logging. I tried it out in PR #2720, and it works because MetricCollection unpacks nested dictionaries returned by the compute method. Thus, if we have a metric collection like this:

train_metrics = MetricCollection(
    {
        "macro_accuracy": MulticlassAccuracy(num_classes=3, average="macro"),
        "classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)),
    }
)

calling compute does not result in this nested dictionary

{"macro_accuracy": ..., {"multiclassaccuracy_0": ..., "multiclassaccuracy_1": ..., "multiclassaccuracy_2": ...} }

instead, it returns a flattened dictionary

{"macro_accuracy": ..., "multiclassaccuracy_0": ..., "multiclassaccuracy_1": ..., "multiclassaccuracy_2": ...} 

This flattened dictionary can be logged directly using self.log_dict. You can check out this integration test:

import torch
from lightning.pytorch.demos.boring_classes import BoringModel
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.wrappers import ClasswiseWrapper


class TestModel(BoringModel):
    def __init__(self) -> None:
        super().__init__()
        self.train_metrics = MetricCollection(
            {
                "macro_accuracy": MulticlassAccuracy(num_classes=5, average="macro"),
                "classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=5, average=None)),
            },
            prefix="train_",
        )
        self.val_metrics = MetricCollection(
            {
                "macro_accuracy": MulticlassAccuracy(num_classes=5, average="macro"),
                "classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=5, average=None)),
            },
            prefix="val_",
        )

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        preds = torch.randint(0, 5, (100,), device=batch.device)
        target = torch.randint(0, 5, (100,), device=batch.device)
        self.train_metrics.update(preds, target)
        batch_values = self.train_metrics.compute()
        self.log_dict(batch_values, on_step=True, on_epoch=False)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        preds = torch.randint(0, 5, (100,), device=batch.device)
        target = torch.randint(0, 5, (100,), device=batch.device)
        self.val_metrics.update(preds, target)

    def on_validation_epoch_end(self):
        self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True)

for an example of how this manual logging looks in PyTorch Lightning using MetricCollection and ClasswiseWrapper.

I hope this is enough feedback on this issue. Otherwise, feel free to reopen it.

@robmarkcole
Author

@SkafteNicki thanks for the explanation!
Can you also confirm that, if I do not implement the _epoch_end hooks, I can achieve the same result with:

def validation_step(self, batch, batch_idx):
    preds = torch.randint(0, 5, (100,), device=batch.device)
    target = torch.randint(0, 5, (100,), device=batch.device)

    self.val_metrics.update(preds, target)
    self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True)  # Only called at epoch end, and reset is called after

@SkafteNicki
Member

@robmarkcole I am pretty sure you will sadly need the _epoch_end hook. The reason is that when self.log_dict/self.log receives a metric object, it automatically handles aggregation and reset correctly. For manual logging, where you just provide scalar tensor values, Lightning does not know how these should be aggregated "correctly" and will therefore take a mean over the logged values, which in some cases is the right aggregation, but in most cases is not.

The correct way (for manual logging) is therefore:

def validation_step(self, batch, batch_idx):
    preds = torch.randint(0, 5, (100,), device=batch.device)
    target = torch.randint(0, 5, (100,), device=batch.device)

    self.val_metrics.update(preds, target)

def on_validation_epoch_end(self):
    self.log_dict(self.val_metrics.compute())
    self.val_metrics.reset()

(I actually missed the reset in the integration test because I forgot to test for more than one epoch. If it is not called, values will keep aggregating.)

@DimitrisMantas

DimitrisMantas commented Sep 5, 2024

Sorry to jump in, but what about calling self.log_dict(self.metrics(input, target)) in the *_step hooks?

Edit: the docs say this is wrong.
