From e79f45839542c9528c0b630813b7909801662d24 Mon Sep 17 00:00:00 2001 From: Ignacio Arganda-Carreras Date: Mon, 26 Feb 2024 12:02:17 +0100 Subject: [PATCH 1/2] Correct 2d generator for super-resolution --- biapy/data/generators/pair_data_2D_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/biapy/data/generators/pair_data_2D_generator.py b/biapy/data/generators/pair_data_2D_generator.py index c4174bd5..00e600c9 100644 --- a/biapy/data/generators/pair_data_2D_generator.py +++ b/biapy/data/generators/pair_data_2D_generator.py @@ -31,7 +31,7 @@ def ensure_shape(self, img, mask): # Super-resolution check. if random_crops_in_DA is activated the images have not been cropped yet, # so this check can not be done and it will be done in the random crop if not self.random_crops_in_DA and self.Y_provided and self.random_crop_scale != 1: - s = [img.shape[0]*self.random_crop_scale, img.shape[1]*self.random_crop_scale] + s = [img.shape[0]*self.random_crop_scale[0], img.shape[1]*self.random_crop_scale[1]] if all(x!=y for x,y in zip(s,mask.shape[:-1])): raise ValueError("Images loaded need to be LR and its HR version. LR shape:" " {} vs HR shape {} is not x{} larger".format(img.shape[:-1], mask.shape[:-1], self.random_crop_scale)) From 11a3f5204e25aa9bd45ac907952236468965d2ac Mon Sep 17 00:00:00 2001 From: Ignacio Arganda-Carreras Date: Mon, 26 Feb 2024 12:02:48 +0100 Subject: [PATCH 2/2] Make sure RCAN receives scale as a int --- biapy/models/__init__.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/biapy/models/__init__.py b/biapy/models/__init__.py index 73967cac..f56c97d1 100644 --- a/biapy/models/__init__.py +++ b/biapy/models/__init__.py @@ -101,7 +101,10 @@ def build_model(cfg, job_identifier, device): model = EDSR(ndim=ndim, num_filters=64, num_of_residual_blocks=16, upsampling_factor=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING, num_channels=cfg.DATA.PATCH_SIZE[-1]) elif modelname == 'rcan': - model = rcan(ndim=ndim, filters=16, n_sub_block=int(np.log2(cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING)), num_channels=cfg.DATA.PATCH_SIZE[-1]) + scale = cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING + if type(scale) is tuple: + scale = scale[0] + model = rcan(ndim=ndim, filters=16, scale=scale, n_sub_block=int(np.log2(scale)), num_channels=cfg.DATA.PATCH_SIZE[-1]) elif modelname == 'dfcan': model = DFCAN(ndim=ndim, input_shape=cfg.DATA.PATCH_SIZE, scale=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING, n_ResGroup = 4, n_RCAB = 4) elif modelname == 'wdsr': @@ -212,4 +215,4 @@ def build_torchvision_model(cfg, device): summary(model, input_size=sample_size, col_names=("input_size", "output_size", "num_params"), depth=10, device="cpu" if "cuda" not in device.type else "cuda") - return model, model_torchvision_weights.transforms() \ No newline at end of file + return model, model_torchvision_weights.transforms()