Skip to content

Commit

Permalink
Update SR upscaling to other workflows and adapt random patch extract…
Browse files Browse the repository at this point in the history
…ion in pair data generator
  • Loading branch information
danifranco committed Feb 25, 2024
1 parent 40b2695 commit a0f3ee1
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 19 deletions.
2 changes: 1 addition & 1 deletion biapy/data/data_3D_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from biapy.utils.util import load_3d_images_from_dir, order_dimensions

def load_and_prepare_3D_data(train_path, train_mask_path, cross_val=False, cross_val_nsplits=5, cross_val_fold=1,
val_split=0.1, seed=0, shuffle_val=True, crop_shape=(80, 80, 80, 1), y_upscaling=1, random_crops_in_DA=False,
val_split=0.1, seed=0, shuffle_val=True, crop_shape=(80, 80, 80, 1), y_upscaling=(1,1,1), random_crops_in_DA=False,
ov=(0,0,0), padding=(0,0,0), minimum_foreground_perc=-1, reflect_to_complete_shape=False, convert_to_rgb=False,
preprocess_cfg=None, is_y_mask=False, preprocess_f=None):
"""
Expand Down
34 changes: 17 additions & 17 deletions biapy/data/generators/augmentors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,7 @@ def GridMask(img, channels, z_size, ratio=0.6, d_range=(30,60), rotate=1, invert


def random_crop_pair(image, mask, random_crop_size, val=False, draw_prob_map_points=False, img_prob=None, weight_map=None,
scale=1):
scale=(1,1)):
"""Random crop for an image and its mask. No crop is done in those dimensions that ``random_crop_size`` is greater than
the input image shape in those dimensions. For instance, if an input image is ``400x150`` and ``random_crop_size``
is ``224x224`` the resulting image will be ``224x150``.
Expand Down Expand Up @@ -1049,8 +1049,8 @@ def random_crop_pair(image, mask, random_crop_size, val=False, draw_prob_map_poi
weight_map : bool, optional
Weight map of the given image. E.g. ``(y, x, channels)``.
scale : int, optional
Scale factor the second image given.
scale : tuple of 2 ints, optional
Scale factor the second image given. E.g. ``(2,2)``.
Returns
-------
Expand Down Expand Up @@ -1116,26 +1116,26 @@ def random_crop_pair(image, mask, random_crop_size, val=False, draw_prob_map_poi
y = np.random.randint(0, height - dy + 1) if height - dy +1 > 0 else 0

# Super-resolution check
if scale != 1:
if any([x != 1 for x in scale]):
img_out_shape = img[y:(y+dy), x:(x+dx)].shape
mask_out_shape = mask[y*scale:(y+dy)*scale, x*scale:(x+dx)*scale].shape
s = [img_out_shape[0]*scale, img_out_shape[1]*scale]
mask_out_shape = mask[y*scale[0]:(y+dy)*scale[0], x*scale[1]:(x+dx)*scale[1]].shape
s = [img_out_shape[0]*scale[0], img_out_shape[1]*scale[1]]
if all(x!=y for x,y in zip(s,mask_out_shape)):
raise ValueError("Images can not be cropped to a PATCH_SIZE of {}. Inputs: LR image shape={} "
"and HR image shape={}. When cropping the output shapes are {} and {}, for LR and HR images respectively. "
"Try to reduce DATA.PATCH_SIZE".format(random_crop_size, img.shape, mask.shape, img_out_shape, mask_out_shape))

if draw_prob_map_points == True:
return img[y:(y+dy), x:(x+dx)], mask[y*scale:(y+dy)*scale, x*scale:(x+dx)*scale], oy, ox, y, x
return img[y:(y+dy), x:(x+dx)], mask[y*scale[0]:(y+dy)*scale[0], x*scale[1]:(x+dx)*scale[1]], oy, ox, y, x
else:
if weight_map is not None:
return img[y:(y+dy), x:(x+dx)], mask[y*scale:(y+dy)*scale, x*scale:(x+dx)*scale], weight_map[y:(y+dy), x:(x+dx)]
return img[y:(y+dy), x:(x+dx)], mask[y*scale[0]:(y+dy)*scale[0], x*scale[1]:(x+dx)*scale[1]], weight_map[y:(y+dy), x:(x+dx)]
else:
return img[y:(y+dy), x:(x+dx)], mask[y*scale:(y+dy)*scale, x*scale:(x+dx)*scale]
return img[y:(y+dy), x:(x+dx)], mask[y*scale[0]:(y+dy)*scale[0], x*scale[1]:(x+dx)*scale[1]]


def random_3D_crop_pair(image, mask, random_crop_size, val=False, img_prob=None, weight_map=None, draw_prob_map_points=False,
scale=1):
scale=(1,1,1)):
"""Extracts a random 3D patch from the given image and mask. No crop is done in those dimensions that ``random_crop_size`` is
greater than the input image shape in those dimensions. For instance, if an input image is ``10x400x150`` and ``random_crop_size``
is ``10x224x224`` the resulting image will be ``10x224x150``.
Expand Down Expand Up @@ -1164,8 +1164,8 @@ def random_3D_crop_pair(image, mask, random_crop_size, val=False, img_prob=None,
draw_prob_map_points : bool, optional
To return the voxel chosen to be the center of the crop.
scale : int, optional
Scale factor the second image given.
scale : tuple of 3 ints, optional
Scale factor the second image given. E.g. ``(2,4,4)``.
Returns
-------
Expand Down Expand Up @@ -1252,21 +1252,21 @@ def random_3D_crop_pair(image, mask, random_crop_size, val=False, img_prob=None,
x = np.random.randint(0, rows - dx + 1) if rows - dx +1 > 0 else 0

# Super-resolution check
if scale != 1:
if any([x != 1 for x in scale]):
img_out_shape = vol[z:(z+dz), y:(y+dy), x:(x+dx)].shape
mask_out_shape = mask[z:(z+dz), y*scale:(y+dy)*scale, x*scale:(x+dx)*scale].shape
s = [img_out_shape[0], img_out_shape[1]*scale, img_out_shape[2]*scale]
mask_out_shape = mask[z*scale[0]:(z+dz)*scale[0], y*scale[1]:(y+dy)*scale[1], x*scale[2]:(x+dx)*scale[2]].shape
s = [img_out_shape[0]*scale[0], img_out_shape[1]*scale[1], img_out_shape[2]*scale[2]]
if all(x!=y for x,y in zip(s,mask_out_shape)):
raise ValueError("Images can not be cropped to a PATCH_SIZE of {}. Inputs: LR image shape={} "
"and HR image shape={}. When cropping the output shapes are {} and {}, for LR and HR images respectively. "
"Try to reduce DATA.PATCH_SIZE".format(random_crop_size, vol.shape, mask.shape, img_out_shape, mask_out_shape))

if draw_prob_map_points:
return vol[z:(z+dz), y:(y+dy), x:(x+dx)], mask[z:(z+dz), y*scale:(y+dy)*scale, x*scale:(x+dx)*scale],\
return vol[z:(z+dz), y:(y+dy), x:(x+dx)], mask[z*scale[0]:(z+dz)*scale[0], y*scale[1]:(y+dy)*scale[1], x*scale[2]:(x+dx)*scale[2]],\
oz, oy, ox, z, y, x
else:
if weight_map is not None:
return vol[z:(z+dz), y:(y+dy), x:(x+dx)], mask[z:(z+dz), y*scale:(y+dy)*scale, x*scale:(x+dx)*scale],\
return vol[z:(z+dz), y:(y+dy), x:(x+dx)], mask[z*scale[0]:(z+dz)*scale[0], y*scale[1]:(y+dy)*scale[1], x*scale[2]:(x+dx)*scale[2]],\
weight_map[z:(z+dz), y:(y+dy), x:(x+dx)]
else:
return vol[z:(z+dz), y:(y+dy), x:(x+dx)], mask[z:(z+dz), y:(y+dy), x:(x+dx)]
Expand Down
8 changes: 7 additions & 1 deletion biapy/engine/check_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,9 @@ def check_configuration(cfg, jobname, check_data_paths=True):
elif cfg.PROBLEM.TYPE == 'SUPER_RESOLUTION':
if not( cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING ):
raise ValueError("Resolution scale must be provided with 'PROBLEM.SUPER_RESOLUTION.UPSCALING' variable")
assert all( i > 0 for i in cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING), "PROBLEM.SUPER_RESOLUTION.UPSCALING are not positive integers"
assert all( i > 0 for i in cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING), "'PROBLEM.SUPER_RESOLUTION.UPSCALING' are not positive integers"
if len(cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING) != dim_count:
raise ValueError(f"'PROBLEM.SUPER_RESOLUTION.UPSCALING' needs to be a tuple of {dim_count} integers")
if cfg.MODEL.SOURCE == "torchvision":
raise ValueError("'MODEL.SOURCE' as 'torchvision' is not available in super-resolution workflow")
if cfg.DATA.NORMALIZATION.TYPE != "div":
Expand Down Expand Up @@ -552,6 +554,10 @@ def check_configuration(cfg, jobname, check_data_paths=True):
elif len(cfg.MODEL.FEATURE_MAPS)-1 != len(cfg.MODEL.Z_DOWN):
raise ValueError("'MODEL.FEATURE_MAPS' length minus one and 'MODEL.Z_DOWN' length must be equal")

# Correct UPSCALING for other workflows than SR
if len(cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING) == 0:
opts.extend(['PROBLEM.SUPER_RESOLUTION.UPSCALING', (1,)*dim_count])

if len(opts) > 0:
cfg.merge_from_list(opts)

Expand Down

0 comments on commit a0f3ee1

Please sign in to comment.