Commit

Merge pull request #158 from kozistr/update/dadaptation
[Update] D-Adaptation v3
kozistr authored May 6, 2023
2 parents 9110bbd + ddf81a7 commit ed97a2f
Showing 3 changed files with 92 additions and 115 deletions.
4 changes: 2 additions & 2 deletions pytorch_optimizer/optimizer/adashift.py
@@ -15,7 +15,7 @@ class AdaShift(Optimizer, BaseOptimizer):
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
:param keep_num: int. number of gradients used to compute first moment estimation.
:param keep_num: int. number of gradients used to compute first moment estimation.
:param reduce_func: Optional[Callable]. function applied to squared gradients to further reduce the correlation.
If None, no function is applied.
:param eps: float. term added to the denominator to improve numerical stability.
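
For context, here is a minimal usage sketch (not part of the commit) for the AdaShift constructor documented above. The import path follows the file shown in this diff; the hyperparameter values are illustrative placeholders, not the library's defaults.

import torch
from pytorch_optimizer.optimizer.adashift import AdaShift

# A tiny model so the optimizer has parameters to manage.
model = torch.nn.Linear(10, 1)

# Construct AdaShift with the parameters described in the docstring above.
optimizer = AdaShift(
    model.parameters(),
    lr=1e-2,             # learning rate
    betas=(0.9, 0.999),  # coefficients for the running averages
    keep_num=10,         # number of gradients kept for the first moment estimate
)

# Standard PyTorch training step.
loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()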
@@ -69,7 +69,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

beta1, beta2 = group['betas']

exp_weight_sum: int = sum(beta1**i for i in range(group['keep_num']))
exp_weight_sum: int = sum(beta1 ** i for i in range(group['keep_num'])) # fmt: skip
first_grad_weight: float = beta1 ** (group['keep_num'] - 1) / exp_weight_sum
last_grad_weight: float = 1.0 / exp_weight_sum

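For reference, here is a standalone sketch (not part of the commit) of the geometric weighting computed in this hunk, using example values beta1 = 0.9 and keep_num = 10 chosen purely for illustration.

beta1: float = 0.9
keep_num: int = 10

# Normalizer: geometric series of decay factors over the window of kept gradients.
exp_weight_sum: float = sum(beta1 ** i for i in range(keep_num))  # ≈ 6.5132

# The oldest gradient in the window receives the smallest weight ...
first_grad_weight: float = beta1 ** (keep_num - 1) / exp_weight_sum  # ≈ 0.0595
# ... and the newest gradient the largest; the full set of weights
# beta1 ** i / exp_weight_sum for i in range(keep_num) sums to 1.
last_grad_weight: float = 1.0 / exp_weight_sum  # ≈ 0.1535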