Skip to content

Commit

Permalink
Add clip into undo_norm_range01 as it's the expected value range
Browse files Browse the repository at this point in the history
  • Loading branch information
danifranco committed Feb 27, 2024
1 parent 3599a7c commit 0b8b98c
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions biapy/data/pre_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def norm_range01(x, dtype=np.float32):
norm_steps = {}
norm_steps['orig_dtype'] = x.dtype

if x.dtype == np.uint8 or x.dtype == torch.uint8:
if x.dtype in [np.uint8, torch.uint8]:
x = x/255
norm_steps['div'] = 1
else:
Expand All @@ -867,20 +867,27 @@ def norm_range01(x, dtype=np.float32):

def undo_norm_range01(x, xnorm):
if 'div' == xnorm['type']:
x = (x*255)
if isinstance(x, np.ndarray):
x = x.astype(np.uint8)
# Prevent values go outside expected range
if isinstance(x, np.ndarray):
x = np.clip(x, 0, 1)
else:
x = x.to(torch.uint8)
reductions = [key for key, value in xnorm.items() if 'reduced' in key.lower()]
if len(reductions)>0:
reductions = reductions[0]
reductions = reductions.replace('reduced_','')
x = (x*65535)
if isinstance(x, np.ndarray):
x = x.astype(eval("np.{}".format(reductions) ))
x = torch.clamp(x, 0, 1)
if 'div' in xnorm:
x = (x*255)
if isinstance(x, np.ndarray):
x = x.astype(np.uint8)
else:
x = x.to(torch.uint8)
else:
x = x.to(eval("torch.{}".format(reductions) ))
reductions = [key for key, value in xnorm.items() if 'reduced' in key.lower()]
if len(reductions)>0:
reductions = reductions[0]
reductions = reductions.replace('reduced_','')
x = (x*65535)
if isinstance(x, np.ndarray):
x = x.astype(eval("np.{}".format(reductions) ))
else:
x = x.to(eval("torch.{}".format(reductions) ))
return x

def reduce_dtype(x, x_min, x_max, out_min=0, out_max=1, out_type=np.float32):
Expand Down

0 comments on commit 0b8b98c

Please sign in to comment.