diff --git a/biapy/engine/image_to_image.py b/biapy/engine/image_to_image.py index 8b065bdd..1250ea97 100644 --- a/biapy/engine/image_to_image.py +++ b/biapy/engine/image_to_image.py @@ -199,9 +199,9 @@ def process_sample(self, norm): # Calculate PSNR if pred.dtype == np.dtype('uint16'): pred = pred.astype(np.float32) - if self._Y.dtype == np.dtype('uint16'): - self._Y = self._Y.astype(np.float32) if self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST: + if self._Y.dtype == np.dtype('uint16'): + self._Y = self._Y.astype(np.float32) psnr_merge_patches = self.metrics[0](torch.from_numpy(pred), torch.from_numpy(self._Y)) self.stats['psnr_merge_patches'] += psnr_merge_patches diff --git a/biapy/engine/super_resolution.py b/biapy/engine/super_resolution.py index 8f270259..7307a905 100644 --- a/biapy/engine/super_resolution.py +++ b/biapy/engine/super_resolution.py @@ -207,9 +207,10 @@ def process_sample(self, norm): # Calculate PSNR if pred.dtype == np.dtype('uint16'): pred = pred.astype(np.float32) - if self._Y.dtype == np.dtype('uint16'): - self._Y = self._Y.astype(np.float32) + if self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST: + if self._Y.dtype == np.dtype('uint16'): + self._Y = self._Y.astype(np.float32) psnr_merge_patches = self.metrics[0](torch.from_numpy(pred), torch.from_numpy(self._Y)) self.stats['psnr_merge_patches'] += psnr_merge_patches