From 6cc41736323307649b669dfcb7ee2dc8fca35685 Mon Sep 17 00:00:00 2001
From: Huilin Qu
Date: Tue, 7 May 2024 14:12:52 +0200
Subject: [PATCH] Add an option to freeze certain model weights.

---
 weaver/train.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/weaver/train.py b/weaver/train.py
index e0bd0ec6..56e75e1f 100644
--- a/weaver/train.py
+++ b/weaver/train.py
@@ -80,6 +80,8 @@
                     help='initialize model with pre-trained weights')
 parser.add_argument('--exclude-model-weights', type=str, default=None,
                     help='comma-separated regex to exclude matched weights from being loaded, e.g., `a.fc..+,b.fc..+`')
+parser.add_argument('--freeze-model-weights', type=str, default=None,
+                    help='comma-separated regex to freeze matched weights during training, e.g., `a.fc..+,b.fc..+`')
 parser.add_argument('--num-epochs', type=int, default=20,
                     help='number of epochs')
 parser.add_argument('--steps-per-epoch', type=int, default=None,
@@ -581,6 +583,19 @@ def model_setup(args, data_config, device='cpu'):
         missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)
         _logger.info('Model initialized with weights from %s\n ... Missing: %s\n ... Unexpected: %s' % (
             args.load_model_weights, missing_keys, unexpected_keys))
+    if args.freeze_model_weights:
+        import re
+        freeze_patterns = args.freeze_model_weights.split(',')
+        for name, param in model.named_parameters():
+            freeze = False
+            for pattern in freeze_patterns:
+                if re.match(pattern, name):
+                    freeze = True
+                    break
+            if freeze:
+                param.requires_grad = False
+        _logger.info('The following weights have been frozen:\n - %s',
+                     '\n - '.join([name for name, p in model.named_parameters() if not p.requires_grad]))
     # _logger.info(model)
     flops(model, model_info, device=device)
     # loss function
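For reference, below is a minimal, self-contained sketch of the freezing logic this patch adds, run against a hypothetical two-layer model (the `nn.Sequential` model and the `0\..+` pattern are illustrative only, not part of weaver):

```python
import re

import torch.nn as nn

# Hypothetical stand-in for a real weaver network.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# Same matching rule as the patch: comma-separated regexes, each
# tested with re.match, i.e. anchored at the start of the parameter
# name only.
freeze_patterns = r'0\..+'.split(',')  # freeze all parameters of the first layer

for name, param in model.named_parameters():
    if any(re.match(pattern, name) for pattern in freeze_patterns):
        param.requires_grad = False  # frozen: no gradient, hence no update

print([name for name, p in model.named_parameters() if not p.requires_grad])
# ['0.weight', '0.bias']
```

Two things worth noting. Because the patterns are tested with `re.match` rather than `re.fullmatch`, a pattern such as `a.fc` would also freeze `a.fc2.weight`; append `$` (or use the `..+` form from the help text) to pin down the match. And in standard PyTorch, frozen parameters can still be passed to an optimizer safely, since parameters whose `.grad` is `None` are skipped during the update step. A hypothetical invocation (other required flags omitted) might look like:

```
weaver --load-model-weights pretrained.pt --freeze-model-weights 'a.fc..+' ...
```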