[Feature] Support calculating loss in the validation step #1486
Comments
Hi, I checked the current work in #1503.

MMEngine

To be more general and follow the current implementation style, the val_step and test_step methods could take a loss flag:

    def val_step(
        self,
        data: Union[tuple, dict, list],
        loss: bool = False,
    ) -> Union[list, Tuple[list, Dict[str, torch.Tensor]]]:
        """Gets the predictions of given data.

        Calls ``self.data_preprocessor(data, False)`` and
        ``self(inputs, data_sample, mode='predict')`` in order. Return the
        predictions which will be passed to evaluator.

        Args:
            data (dict or tuple or list): Data sampled from dataset.
            loss (bool): Whether to also compute the losses.
                Defaults to False.

        Returns:
            list: The predictions of given data.
            tuple: The predictions and a dict of losses, if ``loss`` is True.
        """
        data = self.data_preprocessor(data, False)
        if loss:
            return self._run_forward(data, mode='loss_and_predict')  # type: ignore
        else:
            return self._run_forward(data, mode='predict')  # type: ignore

    def test_step(
        self,
        data: Union[dict, tuple, list],
        loss: bool = False,
    ) -> Union[list, Tuple[list, Dict[str, torch.Tensor]]]:
        """``BaseModel`` implements ``test_step`` the same as ``val_step``.

        Args:
            data (dict or tuple or list): Data sampled from dataset.
            loss (bool): Whether to also compute the losses.
                Defaults to False.

        Returns:
            list: The predictions of given data.
            tuple: The predictions and a dict of losses, if ``loss`` is True.
        """
        data = self.data_preprocessor(data, False)
        if loss:
            return self._run_forward(data, mode='loss_and_predict')  # type: ignore
        else:
            return self._run_forward(data, mode='predict')  # type: ignore

Downstream libraries

As for the implementation of loss_and_predict in the downstream libraries: although we need to run inference twice to get the results, I think there is no other choice because of backward compatibility.

    def loss_and_predict(self,
                         batch_inputs: Tensor,
                         batch_data_samples: SampleList,
                         rescale: bool = True) -> Tuple[SampleList, dict]:
        """Predict results and calculate losses from a batch of inputs and
        data samples with post-processing.

        Args:
            batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the input images.
            Each DetDataSample usually contains 'pred_instances'. And the
            `pred_instances` usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).

            dict: A dictionary of loss components
        """
        # Two separate passes: one for predictions, one for losses.
        preds = self.predict(batch_inputs, batch_data_samples, rescale=rescale)
        losses = self.loss(batch_inputs, batch_data_samples)
        return preds, losses

I think this is a more elegant and ideal solution.
Hi @guyleaf, we have discussed your proposal before. Your solution is elegant, but it requires running the model forward twice.
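To make that cost concrete, here is a toy, framework-free sketch. ToyDetector, extract_feat and the forward_calls counter are hypothetical names chosen for illustration, not MMEngine or MMDetection code. It composes loss_and_predict from predict and loss in the same way as the proposal, so every call passes through the shared feature extractor twice.

```python
class ToyDetector:
    """Hypothetical stand-in for a downstream model (not a real library class)."""

    def __init__(self):
        # Counts passes through the shared feature extractor.
        self.forward_calls = 0

    def extract_feat(self, batch_inputs):
        self.forward_calls += 1
        return [2.0 * x for x in batch_inputs]

    def predict(self, batch_inputs):
        return self.extract_feat(batch_inputs)

    def loss(self, batch_inputs, targets):
        feats = self.extract_feat(batch_inputs)
        mse = sum((f - t) ** 2 for f, t in zip(feats, targets)) / len(feats)
        return {"loss_mse": mse}

    def loss_and_predict(self, batch_inputs, targets):
        # Built from the two existing entry points, so features are computed twice.
        preds = self.predict(batch_inputs)
        losses = self.loss(batch_inputs, targets)
        return preds, losses


model = ToyDetector()
preds, losses = model.loss_and_predict([1.0, 2.0], targets=[2.0, 5.0])
print(model.forward_calls)  # 2
```

Avoiding the second pass would mean sharing the extracted features between the two branches, which is a change each downstream model would have to make itself.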
What is the feature?
We often receive requests from users who want to print the loss during the validation phase, but MMEngine does not support this at the moment. If you also need this, feel free to join the discussion.
There are two possible solutions for MMEngine to support this feature.
One is to add a LossMetric to the downstream library. The forward method of the downstream library still returns a list of DataElement, and the LossMetric calculates the loss from the information in each DataElement.
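A minimal, framework-free sketch of this first option follows. The class is a hypothetical stand-in that only loosely mirrors the process / compute_metrics split of an evaluator metric. It assumes each sample is a plain dict with class probabilities under 'pred_score' and the ground-truth index under 'gt_label'; those keys, and the 'val_loss' name, are illustrative assumptions rather than a fixed API.

```python
import math


class LossMetric:
    """Hypothetical metric that derives a loss from per-sample predictions."""

    def __init__(self):
        self.results = []

    def process(self, data_batch, data_samples):
        # Assumption: every sample carries the predicted class probabilities
        # and the ground-truth class index, so the loss can be recomputed here.
        for sample in data_samples:
            prob = sample["pred_score"][sample["gt_label"]]
            # Per-sample cross-entropy, clamped to avoid log(0).
            self.results.append(-math.log(max(prob, 1e-12)))

    def compute_metrics(self, results):
        return {"val_loss": sum(results) / len(results)}

    def evaluate(self):
        return self.compute_metrics(self.results)


metric = LossMetric()
metric.process(None, [
    {"pred_score": [0.5, 0.5], "gt_label": 0},
    {"pred_score": [0.25, 0.75], "gt_label": 1},
])
print(metric.evaluate())
```

This keeps the model's forward unchanged, but it only works when the loss can be recovered from what the model already returns.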
The other is for the forward method of the downstream library to return a dictionary. The downstream library then has to finish the loss calculation in forward and return it to MMEngine.
https://github.com/open-mmlab/mmpretrain/blob/17a886cb5825cd8c26df4e65f7112d404b99fe12/mmpretrain/models/classifiers/image.py#L87
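The second option could look roughly like the sketch below. Both functions are hypothetical and framework-free, and the 'predictions' and 'loss' keys are assumptions made for illustration. The forward pass computes the loss once and returns it next to the predictions, and the validation loop averages it over batches.

```python
def forward(inputs, targets):
    """Hypothetical model forward returning predictions and loss together."""
    preds = [2.0 * x for x in inputs]
    loss = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
    return {"predictions": preds, "loss": loss}


def validate(batches):
    """Hypothetical validation loop: collects predictions, averages the loss."""
    all_preds, losses = [], []
    for inputs, targets in batches:
        outputs = forward(inputs, targets)  # single pass per batch
        all_preds.extend(outputs["predictions"])
        losses.append(outputs["loss"])
    return all_preds, sum(losses) / len(losses)


preds, val_loss = validate([([1.0], [2.0]), ([2.0], [5.0])])
print(preds, val_loss)
```

Here the loss comes from the same pass as the predictions, at the price of changing what forward returns.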
References
Any other context?
No response