diff --git a/biapy/engine/base_workflow.py b/biapy/engine/base_workflow.py index 76938fe3..ed08f7a4 100644 --- a/biapy/engine/base_workflow.py +++ b/biapy/engine/base_workflow.py @@ -1739,10 +1739,6 @@ def process_test_sample(self): if self.cfg.MODEL.N_CLASSES > 2 and self.cfg.DATA.TEST.ARGMAX_TO_OUTPUT: _type = np.uint8 if self.cfg.MODEL.N_CLASSES < 255 else np.uint16 pred = np.expand_dims(np.argmax(pred, -1), -1).astype(_type) - if self.current_sample["Y"] is not None: - self.current_sample["Y"] = np.expand_dims(np.argmax(self.current_sample["Y"], -1), -1).astype( - _type - ) # Calculate the metrics if self.current_sample["Y"] is not None: @@ -1858,10 +1854,6 @@ def process_test_sample(self): if self.cfg.MODEL.N_CLASSES > 2 and self.cfg.DATA.TEST.ARGMAX_TO_OUTPUT: _type = np.uint8 if self.cfg.MODEL.N_CLASSES < 255 else np.uint16 pred = np.expand_dims(np.argmax(pred, -1), -1).astype(_type) - if self.current_sample["Y"] is not None: - self.current_sample["Y"] = np.expand_dims(np.argmax(self.current_sample["Y"], -1), -1).astype( - _type - ) if self.cfg.TEST.POST_PROCESSING.APPLY_MASK: pred = apply_binary_mask(pred, self.cfg.DATA.TEST.BINARY_MASKS) diff --git a/biapy/engine/metrics.py b/biapy/engine/metrics.py index 9bb9e1d9..a50157ab 100644 --- a/biapy/engine/metrics.py +++ b/biapy/engine/metrics.py @@ -183,12 +183,12 @@ def __call__(self, y_pred, y_true): y_true = torch.cat((1 - y_true, y_true), 1) if self.num_classes > 2: - return self.jaccard( - y_pred, - (y_true.squeeze() if y_true.shape[0] > 1 else y_true.squeeze().unsqueeze(0)), - ) - else: - return self.jaccard(y_pred, y_true) + if y_pred.shape[1] > 1: + y_true = y_true.squeeze() + if len(y_pred.shape)-2 == len(y_true.shape): + y_true = y_true.unsqueeze(0) + + return self.jaccard(y_pred, y_true) class instance_metrics: