diff --git a/biapy/data/pre_processing.py b/biapy/data/pre_processing.py index 4cd31182..7f35f184 100644 --- a/biapy/data/pre_processing.py +++ b/biapy/data/pre_processing.py @@ -846,7 +846,7 @@ def norm_range01(x, dtype=np.float32): norm_steps = {} norm_steps['orig_dtype'] = x.dtype - if x.dtype == np.uint8 or x.dtype == torch.uint8: + if x.dtype in [np.uint8, torch.uint8]: x = x/255 norm_steps['div'] = 1 else: @@ -867,20 +867,27 @@ def norm_range01(x, dtype=np.float32): def undo_norm_range01(x, xnorm): if 'div' == xnorm['type']: - x = (x*255) - if isinstance(x, np.ndarray): - x = x.astype(np.uint8) + # Prevent values go outside expected range + if isinstance(x, np.ndarray): + x = np.clip(x, 0, 1) else: - x = x.to(torch.uint8) - reductions = [key for key, value in xnorm.items() if 'reduced' in key.lower()] - if len(reductions)>0: - reductions = reductions[0] - reductions = reductions.replace('reduced_','') - x = (x*65535) - if isinstance(x, np.ndarray): - x = x.astype(eval("np.{}".format(reductions) )) + x = torch.clamp(x, 0, 1) + if 'div' in xnorm: + x = (x*255) + if isinstance(x, np.ndarray): + x = x.astype(np.uint8) + else: + x = x.to(torch.uint8) else: - x = x.to(eval("torch.{}".format(reductions) )) + reductions = [key for key, value in xnorm.items() if 'reduced' in key.lower()] + if len(reductions)>0: + reductions = reductions[0] + reductions = reductions.replace('reduced_','') + x = (x*65535) + if isinstance(x, np.ndarray): + x = x.astype(eval("np.{}".format(reductions) )) + else: + x = x.to(eval("torch.{}".format(reductions) )) return x def reduce_dtype(x, x_min, x_max, out_min=0, out_max=1, out_type=np.float32):