diff --git a/biapy/models/__init__.py b/biapy/models/__init__.py index 73967cac..f56c97d1 100644 --- a/biapy/models/__init__.py +++ b/biapy/models/__init__.py @@ -101,7 +101,10 @@ def build_model(cfg, job_identifier, device): model = EDSR(ndim=ndim, num_filters=64, num_of_residual_blocks=16, upsampling_factor=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING, num_channels=cfg.DATA.PATCH_SIZE[-1]) elif modelname == 'rcan': - model = rcan(ndim=ndim, filters=16, n_sub_block=int(np.log2(cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING)), num_channels=cfg.DATA.PATCH_SIZE[-1]) + scale = cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING + if type(scale) is tuple: + scale = scale[0] + model = rcan(ndim=ndim, filters=16, scale=scale, n_sub_block=int(np.log2(scale)), num_channels=cfg.DATA.PATCH_SIZE[-1]) elif modelname == 'dfcan': model = DFCAN(ndim=ndim, input_shape=cfg.DATA.PATCH_SIZE, scale=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING, n_ResGroup = 4, n_RCAB = 4) elif modelname == 'wdsr': @@ -212,4 +215,4 @@ def build_torchvision_model(cfg, device): summary(model, input_size=sample_size, col_names=("input_size", "output_size", "num_params"), depth=10, device="cpu" if "cuda" not in device.type else "cuda") - return model, model_torchvision_weights.transforms() \ No newline at end of file + return model, model_torchvision_weights.transforms()