Skip to content

Commit

Permalink
add adagrad for torch (keras-team#548)
Browse files Browse the repository at this point in the history
Co-authored-by: Haifeng Jin <[email protected]>
  • Loading branch information
haifeng-jin and haifeng-jin authored Jul 19, 2023
1 parent 0bf7325 commit dc9d822
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 4 deletions.
37 changes: 37 additions & 0 deletions keras_core/backend/torch/optimizers/torch_adagrad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch

from keras_core import ops
from keras_core import optimizers
from keras_core.backend.torch.optimizers import torch_parallel_optimizer


class Adagrad(
torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adagrad
):
def _parallel_update_step(
self,
grads,
variables,
learning_rate,
):
keras_variables = variables
variables = [v.value for v in variables]

dtype = variables[0].dtype
lr = ops.cast(learning_rate, dtype)

accumulators = [
self._accumulators[self._get_variable_index(variable)].value
for variable in keras_variables
]
torch._foreach_add_(accumulators, torch._foreach_mul(grads, grads))
torch._foreach_add_(
variables,
torch._foreach_div(
torch._foreach_mul(grads, lr),
torch._foreach_sqrt(
torch._foreach_add(accumulators, self.epsilon)
),
),
alpha=-1,
)
2 changes: 2 additions & 0 deletions keras_core/backend/torch/optimizers/torch_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ class TorchOptimizer(BaseOptimizer):
def __new__(cls, *args, **kwargs):
# Import locally to avoid circular imports.
from keras_core.backend.torch.optimizers import torch_adadelta
from keras_core.backend.torch.optimizers import torch_adagrad
from keras_core.backend.torch.optimizers import torch_adam
from keras_core.backend.torch.optimizers import torch_adamw
from keras_core.backend.torch.optimizers import torch_rmsprop
from keras_core.backend.torch.optimizers import torch_sgd

OPTIMIZERS = {
optimizers.Adadelta: torch_adadelta.Adadelta,
optimizers.Adagrad: torch_adagrad.Adagrad,
optimizers.Adam: torch_adam.Adam,
optimizers.AdamW: torch_adamw.AdamW,
optimizers.RMSprop: torch_rmsprop.RMSprop,
Expand Down
9 changes: 5 additions & 4 deletions keras_core/optimizers/adagrad_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

from keras_core import backend
from keras_core import ops
from keras_core import testing
from keras_core.optimizers.adagrad import Adagrad

Expand All @@ -19,7 +20,7 @@ def test_config(self):

def test_single_step(self):
optimizer = Adagrad(learning_rate=0.5)
grads = np.array([1.0, 6.0, 7.0, 2.0])
grads = ops.array([1.0, 6.0, 7.0, 2.0])
vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
optimizer.apply_gradients(zip([grads], [vars]))
self.assertAllClose(
Expand All @@ -28,7 +29,7 @@ def test_single_step(self):

def test_weight_decay(self):
grads, var1, var2, var3 = (
np.zeros(()),
ops.zeros(()),
backend.Variable(2.0),
backend.Variable(2.0, name="exclude"),
backend.Variable(2.0),
Expand All @@ -54,8 +55,8 @@ def test_correctness_with_golden(self):
)

x = backend.Variable(np.ones([10]))
grads = np.arange(0.1, 1.1, 0.1)
first_grads = np.full((10,), 0.01)
grads = ops.arange(0.1, 1.1, 0.1)
first_grads = ops.full((10,), 0.01)

# fmt: off
golden = np.array(
Expand Down

0 comments on commit dc9d822

Please sign in to comment.