From dc9d82253e3dfb5f23b902a3add53b057e5d47ed Mon Sep 17 00:00:00 2001
From: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com>
Date: Wed, 19 Jul 2023 11:25:46 -0700
Subject: [PATCH] add adagrad for torch (#548)

Co-authored-by: Haifeng Jin
---
 .../backend/torch/optimizers/torch_adagrad.py | 37 +++++++++++++++++++
 .../torch/optimizers/torch_optimizer.py       |  2 +
 keras_core/optimizers/adagrad_test.py         |  9 +++--
 3 files changed, 44 insertions(+), 4 deletions(-)
 create mode 100644 keras_core/backend/torch/optimizers/torch_adagrad.py

diff --git a/keras_core/backend/torch/optimizers/torch_adagrad.py b/keras_core/backend/torch/optimizers/torch_adagrad.py
new file mode 100644
index 0000000000..117a93bc4c
--- /dev/null
+++ b/keras_core/backend/torch/optimizers/torch_adagrad.py
@@ -0,0 +1,37 @@
+import torch
+
+from keras_core import ops
+from keras_core import optimizers
+from keras_core.backend.torch.optimizers import torch_parallel_optimizer
+
+
+class Adagrad(
+    torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adagrad
+):
+    def _parallel_update_step(
+        self,
+        grads,
+        variables,
+        learning_rate,
+    ):
+        keras_variables = variables
+        variables = [v.value for v in variables]
+
+        dtype = variables[0].dtype
+        lr = ops.cast(learning_rate, dtype)
+
+        accumulators = [
+            self._accumulators[self._get_variable_index(variable)].value
+            for variable in keras_variables
+        ]
+        torch._foreach_add_(accumulators, torch._foreach_mul(grads, grads))
+        torch._foreach_add_(
+            variables,
+            torch._foreach_div(
+                torch._foreach_mul(grads, lr),
+                torch._foreach_sqrt(
+                    torch._foreach_add(accumulators, self.epsilon)
+                ),
+            ),
+            alpha=-1,
+        )
diff --git a/keras_core/backend/torch/optimizers/torch_optimizer.py b/keras_core/backend/torch/optimizers/torch_optimizer.py
index ca83cd84bd..57cb44ccb0 100644
--- a/keras_core/backend/torch/optimizers/torch_optimizer.py
+++ b/keras_core/backend/torch/optimizers/torch_optimizer.py
@@ -8,6 +8,7 @@ class TorchOptimizer(BaseOptimizer):
     def __new__(cls, *args, **kwargs):
         # Import locally to avoid circular imports.
         from keras_core.backend.torch.optimizers import torch_adadelta
+        from keras_core.backend.torch.optimizers import torch_adagrad
         from keras_core.backend.torch.optimizers import torch_adam
         from keras_core.backend.torch.optimizers import torch_adamw
         from keras_core.backend.torch.optimizers import torch_rmsprop
@@ -15,6 +16,7 @@ def __new__(cls, *args, **kwargs):
 
         OPTIMIZERS = {
             optimizers.Adadelta: torch_adadelta.Adadelta,
+            optimizers.Adagrad: torch_adagrad.Adagrad,
             optimizers.Adam: torch_adam.Adam,
             optimizers.AdamW: torch_adamw.AdamW,
             optimizers.RMSprop: torch_rmsprop.RMSprop,
diff --git a/keras_core/optimizers/adagrad_test.py b/keras_core/optimizers/adagrad_test.py
index 3112f345b3..a531090294 100644
--- a/keras_core/optimizers/adagrad_test.py
+++ b/keras_core/optimizers/adagrad_test.py
@@ -4,6 +4,7 @@
 import numpy as np
 
 from keras_core import backend
+from keras_core import ops
 from keras_core import testing
 from keras_core.optimizers.adagrad import Adagrad
 
@@ -19,7 +20,7 @@ def test_config(self):
 
     def test_single_step(self):
         optimizer = Adagrad(learning_rate=0.5)
-        grads = np.array([1.0, 6.0, 7.0, 2.0])
+        grads = ops.array([1.0, 6.0, 7.0, 2.0])
         vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
         optimizer.apply_gradients(zip([grads], [vars]))
         self.assertAllClose(
@@ -28,7 +29,7 @@ def test_single_step(self):
 
     def test_weight_decay(self):
         grads, var1, var2, var3 = (
-            np.zeros(()),
+            ops.zeros(()),
             backend.Variable(2.0),
             backend.Variable(2.0, name="exclude"),
             backend.Variable(2.0),
@@ -54,8 +55,8 @@ def test_correctness_with_golden(self):
         )
 
         x = backend.Variable(np.ones([10]))
-        grads = np.arange(0.1, 1.1, 0.1)
-        first_grads = np.full((10,), 0.01)
+        grads = ops.arange(0.1, 1.1, 0.1)
+        first_grads = ops.full((10,), 0.01)
 
         # fmt: off
         golden = np.array(
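For reference, the fused torch._foreach_* calls in torch_adagrad.py apply the standard Adagrad update to every variable in a single pass: the squared gradient is accumulated, then each variable is decremented by lr * grad / sqrt(accumulator + epsilon). A minimal per-tensor sketch of the same arithmetic in plain PyTorch (not part of the patch; adagrad_step, lr, epsilon, and the initial accumulator value of 0.1 are illustrative assumptions):

import torch

def adagrad_step(variable, grad, accumulator, lr=0.001, epsilon=1e-7):
    # accumulator <- accumulator + grad^2
    accumulator.add_(grad * grad)
    # variable <- variable - lr * grad / sqrt(accumulator + epsilon)
    variable.add_(-(lr * grad) / torch.sqrt(accumulator + epsilon))
    return variable, accumulator

# Example mirroring test_single_step above (learning_rate=0.5):
v = torch.tensor([1.0, 2.0, 3.0, 4.0])
g = torch.tensor([1.0, 6.0, 7.0, 2.0])
acc = torch.full_like(v, 0.1)  # assumed initial accumulator value
adagrad_step(v, g, acc, lr=0.5)

The fused version in the patch performs the same two steps across the whole list of variables at once via torch._foreach_add_, torch._foreach_mul, torch._foreach_div, and torch._foreach_sqrt, which avoids a Python-level loop over per-variable kernels.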