-
Notifications
You must be signed in to change notification settings - Fork 364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
gy77/add freeze hook #1387
base: main
Are you sure you want to change the base?
gy77/add freeze hook #1387
Conversation
mmengine/hooks/freeze_hook.py
Outdated
unfreeze_epoch (int): The epoch number to start unfreezing layers. | ||
unfreeze_layers (tuple[str]): Model layers containing the keyword in | ||
unfreeze_layers will unfreeze the gradient. | ||
log_grad (bool): Whether to log the requires_grad of each layer. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
log_grad (bool): Whether to log the requires_grad of each layer. | |
verbose (bool): Whether to log the requires_grad of each layer. |
mmengine/hooks/freeze_hook.py
Outdated
|
||
Args: | ||
freeze_epoch (int): The epoch number to start freezing layers. | ||
freeze_layers (tuple[str]): Model layers containing the keyword in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggest making freeze_layers
the first argument, and it should be a regex expression
mmengine/hooks/freeze_hook.py
Outdated
self.unfreeze_layers = unfreeze_layers | ||
self.log_grad = log_grad | ||
|
||
def modify_layers_grad(self, model, layers, requires_grad): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def modify_layers_grad(self, model, layers, requires_grad): | |
def _modify_layers_grad(self, model, layers, requires_grad): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update the type hint
mmengine/hooks/freeze_hook.py
Outdated
v.requires_grad = requires_grad | ||
break | ||
|
||
def log_model_grad(self, model, log_grad=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def log_model_grad(self, model, log_grad=False): | |
def _log_model_grad(self, model, log_grad=False): |
mmengine/hooks/freeze_hook.py
Outdated
|
||
def __init__( | ||
self, | ||
freeze_layers: Union[Sequence[str], str], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since it has been a regex expression, it is not necessary to make it a tuple of str ('exp1|exp2|exp3'
is enough)
mmengine/hooks/freeze_hook.py
Outdated
(tuple, list)) and not isinstance(freeze_layers[0], str): | ||
raise TypeError( | ||
'`freeze_layers` must be a tuple or list of string') | ||
if not isinstance(freeze_iter, (int, type(None))): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if not isinstance(freeze_iter, (int, type(None))): | |
if not isinstance(freeze_iter) and freeze_iter is not None: |
mmengine/hooks/freeze_hook.py
Outdated
if not isinstance(verbose, bool): | ||
raise TypeError('`verbose` must be a boolean') | ||
# check arguments value | ||
if freeze_iter and freeze_iter < 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if freeze_iter and freeze_iter < 0: | |
if freeze_iter is not None and freeze_iter < 0: |
mmengine/hooks/freeze_hook.py
Outdated
if freeze_iter and freeze_iter < 0: | ||
raise ValueError( | ||
'`freeze_iter` must be greater than or equal to 0') | ||
if freeze_epoch and freeze_epoch < 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if freeze_epoch and freeze_epoch < 0: | |
if freeze_epoch is not None and freeze_epoch < 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Merge this check into:
if (freeze_iter is None) ^ (freeze_epoch is None):
raise ValueError(...)
if freeze_iter is not None and freeze_iter < 0:
raise ValueError(...)
if freeze_epoch is not None and freeze_epoch < 0:
raise ValueError(...)
mmengine/hooks/freeze_hook.py
Outdated
"""Modify the `requires_grad` of the specified layers. | ||
|
||
Args: | ||
model (BaseModel): a BaseModel of mmengine. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
model (BaseModel): a BaseModel of mmengine. | |
model (BaseModel): A BaseModel of mmengine. |
mmengine/hooks/freeze_hook.py
Outdated
|
||
def _modify_layers_grad(self, model: BaseModel, layers: Sequence[str], | ||
requires_grad: bool): | ||
"""Modify the `requires_grad` of the specified layers. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"""Modify the `requires_grad` of the specified layers. | |
"""Modify the ``requires_grad`` of the specified layers. |
mmengine/hooks/freeze_hook.py
Outdated
print_log( | ||
f'{k} requires_grad: {v.requires_grad}', logger='current') | ||
|
||
def _main(self, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def _main(self, | |
def _freeze(self, |
mmengine/hooks/freeze_hook.py
Outdated
if self.freeze_iter is not None: | ||
self._main(runner, runner.iter, self.freeze_iter, | ||
self.unfreeze_iter) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if self.freeze_iter is not None: | |
self._main(runner, runner.iter, self.freeze_iter, | |
self.unfreeze_iter) | |
if self.freeze_iter is not None and runner.iter in (self.freeze_iter, self.unfreeze_iter): | |
self._freeze(runner.model) |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1387 +/- ##
=======================================
Coverage ? 70.43%
=======================================
Files ? 154
Lines ? 14368
Branches ? 2999
=======================================
Hits ? 10120
Misses ? 3773
Partials ? 475
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
Motivation
Motivation:
Goal:
Modification
Add FreezeHook and FreezeHook unit tests.
Use cases
freeze_layers
are freeze beforefreeze_iter/freeze_epoch
starts.unfreeze_layers
are freeze beforeunfreeze_iter/unfreeze_epoch
starts.freeze_layers/unfreeze_layers
matches network layers via regular expressioniter/epoch
starts at 0, with epoch=0 for the first epoch.unfreeze_iter
,unfreeze_epoch
andunfreeze_layers
are optional. Iffreeze_epoch/freeze_iter
is not None,unfreeze_layers
must not be None.freeze_iter
andfreeze_epoch
can be set, as well asunfreeze_iter
andunfreeze_epoch
.verbose
parameter is used to determine whether to print therequires_grad
variable for each model layer.