Commit

update: FAdam optimizer
kozistr committed Jun 2, 2024
1 parent f7c5ec0 commit b8938e1
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions pytorch_optimizer/optimizer/fadam.py
@@ -99,13 +99,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
momentum, fim = state['momentum'], state['fim']
fim.mul_(curr_beta2).addcmul_(grad, grad, value=1.0 - curr_beta2)

-rms_grad = torch.pow(grad, 2).mean().sqrt_()
+rms_grad = grad.pow(2).mean().sqrt_()
curr_eps = min(rms_grad, 1) * group['eps']

-fim_base = torch.pow(fim, group['p']).add_(curr_eps)
-grad_nat = torch.div(grad, fim_base)
+fim_base = fim.pow(group['p']).add_(curr_eps)
+grad_nat = grad / fim_base

-rms = torch.pow(grad_nat, 2).mean().sqrt_()
+rms = grad_nat.pow(2).mean().sqrt_()
divisor = max(1, rms) / group['clip']
grad_nat.div_(divisor)

@@ -119,6 +119,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:

grad_weights.mul_(group['weight_decay']).add_(momentum)

-p.add_(-grad_weights, alpha=group['lr'])
+p.add_(grad_weights, alpha=-group['lr'])

return loss
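
For context, below is a minimal, self-contained sketch of the step logic these hunks touch, written as a plain function rather than the library's `Optimizer.step`. The function name `fadam_step_sketch`, the hyperparameter defaults, and the momentum/weight-decay construction between the two hunks are assumptions for illustration; only the expressions visible in the diff above come from the source.

```python
# Rough sketch of the FAdam update around the changed lines, assuming
# `momentum` and `fim` are per-parameter state tensors and the keyword
# arguments mirror the group[...] keys in the diff. The momentum EMA and
# the construction of `grad_weights` (not shown in the diff hunks) are
# assumptions for illustration.
import torch


def fadam_step_sketch(
    param: torch.Tensor,
    grad: torch.Tensor,
    momentum: torch.Tensor,
    fim: torch.Tensor,
    lr: float = 1e-3,
    curr_beta1: float = 0.9,
    curr_beta2: float = 0.999,
    eps: float = 1e-8,
    p: float = 0.5,
    clip: float = 1.0,
    weight_decay: float = 0.0,
) -> None:
    # EMA of the squared gradient as a diagonal Fisher information estimate.
    fim.mul_(curr_beta2).addcmul_(grad, grad, value=1.0 - curr_beta2)

    # Adaptive epsilon: shrink `eps` when the raw gradient RMS is small.
    rms_grad = grad.pow(2).mean().sqrt_()
    curr_eps = min(rms_grad, 1) * eps

    # "Natural" gradient: divide by fim ** p plus the adaptive epsilon.
    fim_base = fim.pow(p).add_(curr_eps)
    grad_nat = grad / fim_base

    # Clip the natural gradient by its RMS.
    rms = grad_nat.pow(2).mean().sqrt_()
    grad_nat.div_(max(1, rms) / clip)

    # Momentum EMA of the clipped natural gradient (assumed form).
    momentum.mul_(curr_beta1).add_(grad_nat, alpha=1.0 - curr_beta1)

    # Weight decay in the same preconditioned space, also RMS-clipped
    # (assumed form), then combined with momentum as in the second hunk.
    grad_weights = param / fim_base
    rms_w = grad_weights.pow(2).mean().sqrt_()
    grad_weights.div_(max(1, rms_w) / clip)
    grad_weights.mul_(weight_decay).add_(momentum)

    # Sign-flipped update from the second hunk:
    # p.add_(grad_weights, alpha=-group['lr'])
    param.add_(grad_weights, alpha=-lr)
```

The change to the last line is behavior-preserving: `p.add_(-grad_weights, alpha=lr)` and `p.add_(grad_weights, alpha=-lr)` compute the same update, but the latter avoids materializing a negated copy of `grad_weights`, in the same spirit as the `.pow()` and `/` cleanups above.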
