From a0f3ee1e03b9e5a7d749ca029d4e8dcecc02adc1 Mon Sep 17 00:00:00 2001 From: Daniel Franco Date: Sun, 25 Feb 2024 19:31:55 +0100 Subject: [PATCH] Update SR upscaling to other workflows and adapt random patch extraction in pair data generator --- biapy/data/data_3D_manipulation.py | 2 +- biapy/data/generators/augmentors.py | 34 ++++++++++++++--------------- biapy/engine/check_configuration.py | 8 ++++++- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/biapy/data/data_3D_manipulation.py b/biapy/data/data_3D_manipulation.py index 17d33cb1..adec1104 100644 --- a/biapy/data/data_3D_manipulation.py +++ b/biapy/data/data_3D_manipulation.py @@ -8,7 +8,7 @@ from biapy.utils.util import load_3d_images_from_dir, order_dimensions def load_and_prepare_3D_data(train_path, train_mask_path, cross_val=False, cross_val_nsplits=5, cross_val_fold=1, - val_split=0.1, seed=0, shuffle_val=True, crop_shape=(80, 80, 80, 1), y_upscaling=1, random_crops_in_DA=False, + val_split=0.1, seed=0, shuffle_val=True, crop_shape=(80, 80, 80, 1), y_upscaling=(1,1,1), random_crops_in_DA=False, ov=(0,0,0), padding=(0,0,0), minimum_foreground_perc=-1, reflect_to_complete_shape=False, convert_to_rgb=False, preprocess_cfg=None, is_y_mask=False, preprocess_f=None): """ diff --git a/biapy/data/generators/augmentors.py b/biapy/data/generators/augmentors.py index a7fa0770..fd96cc6e 100644 --- a/biapy/data/generators/augmentors.py +++ b/biapy/data/generators/augmentors.py @@ -1020,7 +1020,7 @@ def GridMask(img, channels, z_size, ratio=0.6, d_range=(30,60), rotate=1, invert def random_crop_pair(image, mask, random_crop_size, val=False, draw_prob_map_points=False, img_prob=None, weight_map=None, - scale=1): + scale=(1,1)): """Random crop for an image and its mask. No crop is done in those dimensions that ``random_crop_size`` is greater than the input image shape in those dimensions. For instance, if an input image is ``400x150`` and ``random_crop_size`` is ``224x224`` the resulting image will be ``224x150``. @@ -1049,8 +1049,8 @@ def random_crop_pair(image, mask, random_crop_size, val=False, draw_prob_map_poi weight_map : bool, optional Weight map of the given image. E.g. ``(y, x, channels)``. - scale : int, optional - Scale factor the second image given. + scale : tuple of 2 ints, optional + Scale factor the second image given. E.g. ``(2,2)``. Returns ------- @@ -1116,26 +1116,26 @@ def random_crop_pair(image, mask, random_crop_size, val=False, draw_prob_map_poi y = np.random.randint(0, height - dy + 1) if height - dy +1 > 0 else 0 # Super-resolution check - if scale != 1: + if any([x != 1 for x in scale]): img_out_shape = img[y:(y+dy), x:(x+dx)].shape - mask_out_shape = mask[y*scale:(y+dy)*scale, x*scale:(x+dx)*scale].shape - s = [img_out_shape[0]*scale, img_out_shape[1]*scale] + mask_out_shape = mask[y*scale[0]:(y+dy)*scale[0], x*scale[1]:(x+dx)*scale[1]].shape + s = [img_out_shape[0]*scale[0], img_out_shape[1]*scale[1]] if all(x!=y for x,y in zip(s,mask_out_shape)): raise ValueError("Images can not be cropped to a PATCH_SIZE of {}. Inputs: LR image shape={} " "and HR image shape={}. When cropping the output shapes are {} and {}, for LR and HR images respectively. " "Try to reduce DATA.PATCH_SIZE".format(random_crop_size, img.shape, mask.shape, img_out_shape, mask_out_shape)) if draw_prob_map_points == True: - return img[y:(y+dy), x:(x+dx)], mask[y*scale:(y+dy)*scale, x*scale:(x+dx)*scale], oy, ox, y, x + return img[y:(y+dy), x:(x+dx)], mask[y*scale[0]:(y+dy)*scale[0], x*scale[1]:(x+dx)*scale[1]], oy, ox, y, x else: if weight_map is not None: - return img[y:(y+dy), x:(x+dx)], mask[y*scale:(y+dy)*scale, x*scale:(x+dx)*scale], weight_map[y:(y+dy), x:(x+dx)] + return img[y:(y+dy), x:(x+dx)], mask[y*scale[0]:(y+dy)*scale[0], x*scale[1]:(x+dx)*scale[1]], weight_map[y:(y+dy), x:(x+dx)] else: - return img[y:(y+dy), x:(x+dx)], mask[y*scale:(y+dy)*scale, x*scale:(x+dx)*scale] + return img[y:(y+dy), x:(x+dx)], mask[y*scale[0]:(y+dy)*scale[0], x*scale[1]:(x+dx)*scale[1]] def random_3D_crop_pair(image, mask, random_crop_size, val=False, img_prob=None, weight_map=None, draw_prob_map_points=False, - scale=1): + scale=(1,1,1)): """Extracts a random 3D patch from the given image and mask. No crop is done in those dimensions that ``random_crop_size`` is greater than the input image shape in those dimensions. For instance, if an input image is ``10x400x150`` and ``random_crop_size`` is ``10x224x224`` the resulting image will be ``10x224x150``. @@ -1164,8 +1164,8 @@ def random_3D_crop_pair(image, mask, random_crop_size, val=False, img_prob=None, draw_prob_map_points : bool, optional To return the voxel chosen to be the center of the crop. - scale : int, optional - Scale factor the second image given. + scale : tuple of 3 ints, optional + Scale factor the second image given. E.g. ``(2,4,4)``. Returns ------- @@ -1252,21 +1252,21 @@ def random_3D_crop_pair(image, mask, random_crop_size, val=False, img_prob=None, x = np.random.randint(0, rows - dx + 1) if rows - dx +1 > 0 else 0 # Super-resolution check - if scale != 1: + if any([x != 1 for x in scale]): img_out_shape = vol[z:(z+dz), y:(y+dy), x:(x+dx)].shape - mask_out_shape = mask[z:(z+dz), y*scale:(y+dy)*scale, x*scale:(x+dx)*scale].shape - s = [img_out_shape[0], img_out_shape[1]*scale, img_out_shape[2]*scale] + mask_out_shape = mask[z*scale[0]:(z+dz)*scale[0], y*scale[1]:(y+dy)*scale[1], x*scale[2]:(x+dx)*scale[2]].shape + s = [img_out_shape[0]*scale[0], img_out_shape[1]*scale[1], img_out_shape[2]*scale[2]] if all(x!=y for x,y in zip(s,mask_out_shape)): raise ValueError("Images can not be cropped to a PATCH_SIZE of {}. Inputs: LR image shape={} " "and HR image shape={}. When cropping the output shapes are {} and {}, for LR and HR images respectively. " "Try to reduce DATA.PATCH_SIZE".format(random_crop_size, vol.shape, mask.shape, img_out_shape, mask_out_shape)) if draw_prob_map_points: - return vol[z:(z+dz), y:(y+dy), x:(x+dx)], mask[z:(z+dz), y*scale:(y+dy)*scale, x*scale:(x+dx)*scale],\ + return vol[z:(z+dz), y:(y+dy), x:(x+dx)], mask[z*scale[0]:(z+dz)*scale[0], y*scale[1]:(y+dy)*scale[1], x*scale[2]:(x+dx)*scale[2]],\ oz, oy, ox, z, y, x else: if weight_map is not None: - return vol[z:(z+dz), y:(y+dy), x:(x+dx)], mask[z:(z+dz), y*scale:(y+dy)*scale, x*scale:(x+dx)*scale],\ + return vol[z:(z+dz), y:(y+dy), x:(x+dx)], mask[z*scale[0]:(z+dz)*scale[0], y*scale[1]:(y+dy)*scale[1], x*scale[2]:(x+dx)*scale[2]],\ weight_map[z:(z+dz), y:(y+dy), x:(x+dx)] else: return vol[z:(z+dz), y:(y+dy), x:(x+dx)], mask[z:(z+dz), y:(y+dy), x:(x+dx)] diff --git a/biapy/engine/check_configuration.py b/biapy/engine/check_configuration.py index 429fb8d2..9503f083 100644 --- a/biapy/engine/check_configuration.py +++ b/biapy/engine/check_configuration.py @@ -306,7 +306,9 @@ def check_configuration(cfg, jobname, check_data_paths=True): elif cfg.PROBLEM.TYPE == 'SUPER_RESOLUTION': if not( cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING ): raise ValueError("Resolution scale must be provided with 'PROBLEM.SUPER_RESOLUTION.UPSCALING' variable") - assert all( i > 0 for i in cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING), "PROBLEM.SUPER_RESOLUTION.UPSCALING are not positive integers" + assert all( i > 0 for i in cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING), "'PROBLEM.SUPER_RESOLUTION.UPSCALING' are not positive integers" + if len(cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING) != dim_count: + raise ValueError(f"'PROBLEM.SUPER_RESOLUTION.UPSCALING' needs to be a tuple of {dim_count} integers") if cfg.MODEL.SOURCE == "torchvision": raise ValueError("'MODEL.SOURCE' as 'torchvision' is not available in super-resolution workflow") if cfg.DATA.NORMALIZATION.TYPE != "div": @@ -552,6 +554,10 @@ def check_configuration(cfg, jobname, check_data_paths=True): elif len(cfg.MODEL.FEATURE_MAPS)-1 != len(cfg.MODEL.Z_DOWN): raise ValueError("'MODEL.FEATURE_MAPS' length minus one and 'MODEL.Z_DOWN' length must be equal") + # Correct UPSCALING for other workflows than SR + if len(cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING) == 0: + opts.extend(['PROBLEM.SUPER_RESOLUTION.UPSCALING', (1,)*dim_count]) + if len(opts) > 0: cfg.merge_from_list(opts)