Skip to content

Commit

Permalink
Fix small bug in semantic seg. multiclass jaccard calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
danifranco committed Oct 10, 2024
1 parent e30485d commit 5f9b0aa
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 14 deletions.
8 changes: 0 additions & 8 deletions biapy/engine/base_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1739,10 +1739,6 @@ def process_test_sample(self):
if self.cfg.MODEL.N_CLASSES > 2 and self.cfg.DATA.TEST.ARGMAX_TO_OUTPUT:
_type = np.uint8 if self.cfg.MODEL.N_CLASSES < 255 else np.uint16
pred = np.expand_dims(np.argmax(pred, -1), -1).astype(_type)
if self.current_sample["Y"] is not None:
self.current_sample["Y"] = np.expand_dims(np.argmax(self.current_sample["Y"], -1), -1).astype(
_type
)

# Calculate the metrics
if self.current_sample["Y"] is not None:
Expand Down Expand Up @@ -1858,10 +1854,6 @@ def process_test_sample(self):
if self.cfg.MODEL.N_CLASSES > 2 and self.cfg.DATA.TEST.ARGMAX_TO_OUTPUT:
_type = np.uint8 if self.cfg.MODEL.N_CLASSES < 255 else np.uint16
pred = np.expand_dims(np.argmax(pred, -1), -1).astype(_type)
if self.current_sample["Y"] is not None:
self.current_sample["Y"] = np.expand_dims(np.argmax(self.current_sample["Y"], -1), -1).astype(
_type
)

if self.cfg.TEST.POST_PROCESSING.APPLY_MASK:
pred = apply_binary_mask(pred, self.cfg.DATA.TEST.BINARY_MASKS)
Expand Down
12 changes: 6 additions & 6 deletions biapy/engine/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,12 @@ def __call__(self, y_pred, y_true):
y_true = torch.cat((1 - y_true, y_true), 1)

if self.num_classes > 2:
return self.jaccard(
y_pred,
(y_true.squeeze() if y_true.shape[0] > 1 else y_true.squeeze().unsqueeze(0)),
)
else:
return self.jaccard(y_pred, y_true)
if y_pred.shape[1] > 1:
y_true = y_true.squeeze()
if len(y_pred.shape)-2 == len(y_true.shape):
y_true = y_true.unsqueeze(0)

return self.jaccard(y_pred, y_true)


class instance_metrics:
Expand Down

0 comments on commit 5f9b0aa

Please sign in to comment.