Add an option to freeze certain model weights.
hqucms committed May 7, 2024
1 parent 063f41a commit 6cc4173
Showing 1 changed file with 15 additions and 0 deletions.
weaver/train.py: 15 additions & 0 deletions
@@ -80,6 +80,8 @@
                     help='initialize model with pre-trained weights')
 parser.add_argument('--exclude-model-weights', type=str, default=None,
                     help='comma-separated regex to exclude matched weights from being loaded, e.g., `a.fc..+,b.fc..+`')
+parser.add_argument('--freeze-model-weights', type=str, default=None,
+                    help='comma-separated regex to freeze matched weights from being updated in the training, e.g., `a.fc..+,b.fc..+`')
 parser.add_argument('--num-epochs', type=int, default=20,
                     help='number of epochs')
 parser.add_argument('--steps-per-epoch', type=int, default=None,
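The new option mirrors the comma-separated-regex convention of --exclude-model-weights, but the matched parameters stay in the model and are simply excluded from gradient updates. As an illustration only (the entry point, checkpoint path, and regexes below are hypothetical, not part of this commit), a fine-tuning run that freezes two sub-modules might look like:

    python weaver/train.py \
        --load-model-weights pretrained_state.pt \
        --freeze-model-weights 'mod.embed..+,mod.blocks.0..+' \
        ...  # remaining data/model/training options as usual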
@@ -581,6 +583,19 @@ def model_setup(args, data_config, device='cpu'):
         missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)
         _logger.info('Model initialized with weights from %s\n ... Missing: %s\n ... Unexpected: %s' %
                      (args.load_model_weights, missing_keys, unexpected_keys))
+    if args.freeze_model_weights:
+        import re
+        freeze_patterns = args.freeze_model_weights.split(',')
+        for name, param in model.named_parameters():
+            freeze = False
+            for pattern in freeze_patterns:
+                if re.match(pattern, name):
+                    freeze = True
+                    break
+            if freeze:
+                param.requires_grad = False
+        _logger.info('The following weights have been frozen:\n - %s',
+                     '\n - '.join([name for name, p in model.named_parameters() if not p.requires_grad]))
     # _logger.info(model)
     flops(model, model_info, device=device)
     # loss function
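The freezing itself only sets requires_grad = False on the matched parameters; nothing is removed from the optimizer in this diff, which is sufficient because parameters whose .grad remains None are skipped during optimizer steps. A minimal standalone sketch of the same matching idea (the toy model, the pattern, and the optimizer filtering are illustrative and not taken from weaver):

    import re
    import torch
    from torch import nn

    # Toy model standing in for the network built in model_setup().
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))

    # Same convention as --freeze-model-weights: comma-separated regexes matched
    # against parameter names with re.match, i.e. anchored at the start of the name.
    freeze_patterns = '0..+'.split(',')
    for name, param in model.named_parameters():
        if any(re.match(pattern, name) for pattern in freeze_patterns):
            param.requires_grad = False

    # Optionally hand only the still-trainable parameters to the optimizer.
    optimizer = torch.optim.Adam(p for p in model.parameters() if p.requires_grad)

    print([name for name, p in model.named_parameters() if not p.requires_grad])
    # expected: ['0.weight', '0.bias']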
