Merge pull request #159 from kozistr/fix/bias-correction
[Fix] bias correction in D-Adaptation Adam v3 optimizer
kozistr authored May 6, 2023
2 parents ed97a2f + f5bf44e commit 2e97e5f
Showing 2 changed files with 16 additions and 10 deletions.
16 changes: 11 additions & 5 deletions pytorch_optimizer/optimizer/dadapt.py
@@ -251,7 +251,7 @@ class DAdaptAdam(Optimizer, BaseOptimizer):
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. use AdamW style weight decay.
     :param fixed_decay: bool. fix weight decay.
-    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
+    :param bias_correction: bool. Turn on Adam's bias correction.
     :param eps: float. term added to the denominator to improve numerical stability.
     """

@@ -265,7 +265,7 @@ def __init__(
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
         fixed_decay: bool = False,
-        adam_debias: bool = False,
+        bias_correction: bool = False,
         eps: float = 0.0,
     ):
         self.validate_learning_rate(lr)
@@ -281,7 +281,7 @@ def __init__(
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
             'fixed_decay': fixed_decay,
-            'adam_debias': adam_debias,
+            'bias_correction': bias_correction,
             'step': 0,
             'eps': eps,
         }
@@ -321,8 +321,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         d: float = group['d']
         lr: float = group['lr']

-        bias_correction: float = 1.0 - pow(beta1, group['step'] + 1)
-        d_lr: float = self.apply_adam_debias(group['adam_debias'], step_size=d * lr, bias_correction1=bias_correction)
+        bias_correction1: float = 1.0 - beta1 ** (group['step'] + 1)
+        bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** (group['step'] + 1))
+        bias_correction: float = bias_correction1 / bias_correction2_sq
+
+        # it's not Adam Debias
+        d_lr: float = self.apply_adam_debias(
+            group['bias_correction'], step_size=d * lr, bias_correction1=bias_correction
+        )

         sk_l1 = torch.tensor([0.0], device=device)
         numerator_acc = torch.tensor([0.0], device=device)
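In short, the hunk above replaces the first-moment-only factor 1 - beta1^(step + 1) with the combined Adam factor (1 - beta1^(step + 1)) / sqrt(1 - beta2^(step + 1)), which apply_adam_debias then uses to rescale the D-Adapted step size d * lr according to the new bias_correction flag. A minimal sketch of just that arithmetic, with standalone arguments in place of the optimizer's group state:

import math

def combined_bias_correction(beta1: float, beta2: float, step: int) -> float:
    # Combined Adam bias-correction factor from the new code above:
    # (1 - beta1^(step + 1)) / sqrt(1 - beta2^(step + 1)).
    bias_correction1 = 1.0 - beta1 ** (step + 1)
    bias_correction2_sq = math.sqrt(1.0 - beta2 ** (step + 1))
    return bias_correction1 / bias_correction2_sq

# With the common defaults beta1=0.9, beta2=0.999, the factor is about
# 0.1 / sqrt(0.001) ≈ 3.16 at step 0 and approaches 1.0 as steps accumulate.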
10 changes: 5 additions & 5 deletions tests/constants.py
@@ -310,12 +310,12 @@
     (Adan, {'lr': 5e-1, 'max_grad_norm': 1.0}, 5),
     (Adan, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 5),
     (Adan, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': True}, 5),
-    (DAdaptAdaGrad, {'lr': 2e0, 'weight_decay': 1e-3}, 50),
-    (DAdaptAdaGrad, {'lr': 2e0, 'weight_decay': 1e-3, 'momentum': 0.1}, 50),
-    (DAdaptAdam, {'lr': 5e2, 'weight_decay': 1e-3}, 25),
+    (DAdaptAdaGrad, {'lr': 3e0, 'weight_decay': 1e-3}, 30),
+    (DAdaptAdaGrad, {'lr': 5e0, 'weight_decay': 1e-3, 'momentum': 0.1}, 20),
+    (DAdaptAdam, {'lr': 5e4, 'weight_decay': 1e-1}, 10),
     (DAdaptSGD, {'lr': 2e0, 'weight_decay': 1e-3}, 25),
-    (DAdaptAdan, {'lr': 1e0, 'weight_decay': 1e-2}, 25),
-    (DAdaptAdan, {'lr': 1e0, 'weight_decay': 1e-2, 'weight_decouple': True}, 50),
+    (DAdaptAdan, {'lr': 2e0, 'weight_decay': 1e-3}, 20),
+    (DAdaptAdan, {'lr': 2e0, 'weight_decay': 1e-3, 'weight_decouple': True}, 20),
     (AdamS, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
     (AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20),
     (AdaFactor, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
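Each tuple above pairs an optimizer class with the keyword arguments and the number of optimization steps its test uses. A rough usage sketch for the retuned DAdaptAdam entry, assuming a toy model and loss rather than the repository's actual test harness:

import torch

from pytorch_optimizer import DAdaptAdam

model = torch.nn.Linear(2, 1)
# mirrors the updated (DAdaptAdam, {'lr': 5e4, 'weight_decay': 1e-1}, 10) entry
optimizer = DAdaptAdam(model.parameters(), lr=5e4, weight_decay=1e-1)

for _ in range(10):  # third element of the tuple: number of steps
    optimizer.zero_grad()
    loss = model(torch.randn(8, 2)).pow(2).mean()
    loss.backward()
    optimizer.step()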
