diff --git a/pytorch_optimizer/optimizer/fadam.py b/pytorch_optimizer/optimizer/fadam.py
index d0474bc4a..26ee108f5 100644
--- a/pytorch_optimizer/optimizer/fadam.py
+++ b/pytorch_optimizer/optimizer/fadam.py
@@ -99,13 +99,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 momentum, fim = state['momentum'], state['fim']
 
                 fim.mul_(curr_beta2).addcmul_(grad, grad, value=1.0 - curr_beta2)
 
-                rms_grad = torch.pow(grad, 2).mean().sqrt_()
+                rms_grad = grad.pow(2).mean().sqrt_()
                 curr_eps = min(rms_grad, 1) * group['eps']
 
-                fim_base = torch.pow(fim, group['p']).add_(curr_eps)
-                grad_nat = torch.div(grad, fim_base)
+                fim_base = fim.pow(group['p']).add_(curr_eps)
+                grad_nat = grad / fim_base
 
-                rms = torch.pow(grad_nat, 2).mean().sqrt_()
+                rms = grad_nat.pow(2).mean().sqrt_()
                 divisor = max(1, rms) / group['clip']
                 grad_nat.div_(divisor)
 
@@ -119,6 +119,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 grad_weights.mul_(group['weight_decay']).add_(momentum)
 
-                p.add_(-grad_weights, alpha=group['lr'])
+                p.add_(grad_weights, alpha=-group['lr'])
 
         return loss
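
A minimal sketch (not part of the diff) illustrating the last hunk: folding the sign into `alpha` gives the same parameter update as negating the tensor, but avoids allocating a temporary `-grad_weights`. The tensor shapes and `lr` value below are arbitrary placeholders for illustration.

```python
import torch

# Two equivalent ways to apply the FAdam update step.
p_old = torch.randn(4)
p_new = p_old.clone()
grad_weights = torch.randn(4)
lr = 1e-3

p_old.add_(-grad_weights, alpha=lr)   # old form: materialises -grad_weights
p_new.add_(grad_weights, alpha=-lr)   # new form: negates the scalar step size instead

assert torch.allclose(p_old, p_new)
```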