diff --git a/torch_optimizer/lamb.py b/torch_optimizer/lamb.py
index 4ac5fbb..a1da0ef 100644
--- a/torch_optimizer/lamb.py
+++ b/torch_optimizer/lamb.py
@@ -28,6 +28,7 @@ class Lamb(Optimizer):
         adam: always use trust ratio = 1, which turns this into Adam. Useful
             for comparison purposes. (default: False)
         debias: debias adam by (1 - beta**step) (default: False)
+        prenorm: if True, perform pre-normalization of all gradients
 
     Example:
         >>> import torch_optimizer as optim
@@ -52,6 +53,7 @@ def __init__(
         clamp_value: float = 10,
         adam: bool = False,
         debias: bool = False,
+        prenorm: bool = False,
     ) -> None:
         if lr <= 0.0:
             raise ValueError('Invalid learning rate: {}'.format(lr))
@@ -76,6 +78,7 @@ def __init__(
         self.clamp_value = clamp_value
         self.adam = adam
         self.debias = debias
+        self.prenorm = prenorm
 
         super(Lamb, self).__init__(params, defaults)
 
@@ -89,6 +92,23 @@ def step(self, closure: OptLossClosure = None) -> OptFloat:
         if closure is not None:
             loss = closure()
 
+        if self.prenorm:
+            norm_sq = 0.0
+            for group in self.param_groups:
+                for p in group['params']:
+                    if p.grad is None:
+                        continue
+
+                    norm_sq += torch.linalg.norm(p.grad).item() ** 2
+
+            norm = math.sqrt(norm_sq)
+
+            for group in self.param_groups:
+                for p in group['params']:
+                    if p.grad is None:
+                        continue
+                    p.grad /= norm
+
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
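
For context, a minimal usage sketch of the new flag; the toy model, data, and learning rate below are illustrative assumptions, not part of the patch:

```python
# Minimal sketch, assuming this branch of torch_optimizer is installed;
# the toy model, data, and lr value are illustrative only.
import torch
import torch_optimizer as optim

model = torch.nn.Linear(10, 1)
optimizer = optim.Lamb(model.parameters(), lr=0.1, prenorm=True)

x, y = torch.randn(32, 10), torch.randn(32, 1)

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
# With prenorm=True, step() first divides every gradient (across all
# parameter groups) by the global gradient norm, then applies the usual
# LAMB update.
optimizer.step()
```

With pre-normalization enabled, the combined gradient vector has unit L2 norm before LAMB's per-layer trust-ratio scaling is applied, so the update magnitude no longer depends on the overall scale of the raw gradients.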