Skip to content

Commit

Permalink
Fix PSNR calculation in I2I and SR workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
danifranco committed Mar 8, 2024
1 parent df82322 commit 6ab226a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions biapy/engine/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ def process_sample(self, norm):
# Calculate PSNR
if pred.dtype == np.dtype('uint16'):
pred = pred.astype(np.float32)
if self._Y.dtype == np.dtype('uint16'):
self._Y = self._Y.astype(np.float32)
if self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST:
if self._Y.dtype == np.dtype('uint16'):
self._Y = self._Y.astype(np.float32)
psnr_merge_patches = self.metrics[0](torch.from_numpy(pred), torch.from_numpy(self._Y))
self.stats['psnr_merge_patches'] += psnr_merge_patches

Expand Down
5 changes: 3 additions & 2 deletions biapy/engine/super_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,10 @@ def process_sample(self, norm):
# Calculate PSNR
if pred.dtype == np.dtype('uint16'):
pred = pred.astype(np.float32)
if self._Y.dtype == np.dtype('uint16'):
self._Y = self._Y.astype(np.float32)

if self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST:
if self._Y.dtype == np.dtype('uint16'):
self._Y = self._Y.astype(np.float32)
psnr_merge_patches = self.metrics[0](torch.from_numpy(pred), torch.from_numpy(self._Y))
self.stats['psnr_merge_patches'] += psnr_merge_patches

Expand Down

0 comments on commit 6ab226a

Please sign in to comment.