Skip to content

Commit

Permalink
Merge branch 'master' of github.com:BiaPyX/BiaPy
Browse files Browse the repository at this point in the history
  • Loading branch information
danifranco committed Feb 26, 2024
2 parents 942b3ad + 37f4bc5 commit 075ddfc
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions biapy/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def build_model(cfg, job_identifier, device):
model = EDSR(ndim=ndim, num_filters=64, num_of_residual_blocks=16, upsampling_factor=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING,
num_channels=cfg.DATA.PATCH_SIZE[-1])
elif modelname == 'rcan':
model = rcan(ndim=ndim, filters=16, n_sub_block=int(np.log2(cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING)), num_channels=cfg.DATA.PATCH_SIZE[-1])
scale = cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING
if type(scale) is tuple:
scale = scale[0]
model = rcan(ndim=ndim, filters=16, scale=scale, n_sub_block=int(np.log2(scale)), num_channels=cfg.DATA.PATCH_SIZE[-1])
elif modelname == 'dfcan':
model = DFCAN(ndim=ndim, input_shape=cfg.DATA.PATCH_SIZE, scale=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING, n_ResGroup = 4, n_RCAB = 4)
elif modelname == 'wdsr':
Expand Down Expand Up @@ -212,4 +215,4 @@ def build_torchvision_model(cfg, device):
summary(model, input_size=sample_size, col_names=("input_size", "output_size", "num_params"), depth=10,
device="cpu" if "cuda" not in device.type else "cuda")

return model, model_torchvision_weights.transforms()
return model, model_torchvision_weights.transforms()

0 comments on commit 075ddfc

Please sign in to comment.