Skip to content

Commit

Permalink
Update main_finetune_chestxray.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lambert-x authored Feb 20, 2024
1 parent b0b32f1 commit 5744fa5
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions main_finetune_chestxray.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,16 +286,16 @@ def main(args):
print("Load pre-trained checkpoint from: %s" % args.finetune)
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
if args.global_pool:
for k in ['fc_norm.weight', 'fc_norm.bias']:
try:
del checkpoint_model[k]
except:
pass
for k in checkpoint_model.keys():
if k in state_dict:
if checkpoint_model[k].shape == state_dict[k].shape:
state_dict[k] = checkpoint_model[k]
print(f"Loaded Index: {k} from Saved Weights")
else:
print(f"Shape of {k} doesn't match with {state_dict[k]}")
else:
print(f"{k} not found in Init Model")



# interpolate position embedding
Expand Down

0 comments on commit 5744fa5

Please sign in to comment.