
Change output channels of transfer learning mmar models #611

Answered by yiheng-wang-nv
frmrz asked this question in Q&A

Hi @frmrz, sorry for the late reply.

If the UNet you actually need has some layers whose shapes differ from the pretrained weights, you need to modify the state dict before loading it. For example, if you need a UNet with 5 output channels, the code looks like this:

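```python
from monai.apps import RemoteMMARKeys, load_from_mmar
from monai.networks.layers import Norm
from monai.networks.nets import UNet

# `mmar`, `root_dir` and `device` are assumed to be defined earlier:
# the MMAR description dict, the download directory, and the target device.

# Load the pretrained network (architecture + weights) from the MMAR.
unet_model = load_from_mmar(
    item=mmar[RemoteMMARKeys.NAME],
    mmar_dir=root_dir,
    map_location=device,
    pretrained=True)

# Build the UNet you actually need: same architecture, but 5 output channels.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=5,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)

# State dict of the new model and of the pretrained model.
model_dict = model.state_dict()
pretrain_state_dict = unet_model.state_dict()

# Keep only the pretrained entries that are compatible with the new model:
# pretrain_state_dict = {
#     k: v for k, v in pretrai…
```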
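One common way to finish this step, sketched here as an assumption rather than as the exact code from the reply above, is to keep only the pretrained tensors whose key exists in the new model with the same shape, merge them into `model_dict`, and load the result. So that it runs without downloading an MMAR, the sketch uses small hypothetical `nn.Sequential` stand-ins instead of the MONAI UNet; the helper names `filter_compatible` and `make_net` are also hypothetical.

```python
import torch
from torch import nn


def filter_compatible(pretrained: dict, target: dict) -> dict:
    """Return the pretrained entries whose key exists in `target` with the same shape."""
    return {
        k: v
        for k, v in pretrained.items()
        if k in target and target[k].shape == v.shape
    }


def make_net(out_channels: int) -> nn.Sequential:
    # Hypothetical stand-in for the UNet: only the last layer depends on out_channels.
    return nn.Sequential(
        nn.Conv3d(1, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv3d(16, out_channels, kernel_size=1),
    )


torch.manual_seed(0)
pretrained_net = make_net(out_channels=2)  # e.g. a 2-class pretrained model
model = make_net(out_channels=5)           # the 5-channel model we actually need

model_dict = model.state_dict()
pretrain_state_dict = filter_compatible(pretrained_net.state_dict(), model_dict)

# Overwrite the matching entries and load; the 5-channel head keeps its fresh init.
model_dict.update(pretrain_state_dict)
model.load_state_dict(model_dict)

print(sorted(pretrain_state_dict))  # ['0.bias', '0.weight']
```

Passing `strict=False` to `load_state_dict` is not a substitute for this filtering: it only tolerates missing or unexpected keys, and PyTorch still raises a `RuntimeError` when a key is present with a different shape.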

Replies: 1 comment 1 reply

1 reply
@frmrz

Answer selected by wyli
Category: Q&A
Labels: none yet
Participants: 2