Add sum gradient reduction in tf.distribute case (#19467)
fchollet authored Apr 9, 2024
1 parent ec5eadf commit 8961e3f
Showing 1 changed file with 71 additions and 0 deletions.
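Context for the change: under `tf.distribute`, each replica computes gradients from its local shard of the global batch, so per-replica gradients must be summed across replicas before the update step. A minimal standalone sketch of the replica-context all-reduce this commit builds on (the constant stand-in gradient and toy step are illustrative, not part of the commit):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_step():
    # Each replica holds a local, per-shard gradient; all_reduce with
    # ReduceOp.SUM makes every replica see the global sum.
    local_grad = tf.constant(0.5)
    ctx = tf.distribute.get_replica_context()
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_grad)

# With N replicas, each per-replica result equals N * 0.5.
summed = strategy.run(replica_step)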
71 changes: 71 additions & 0 deletions keras/backend/tensorflow/optimizer.py
@@ -1,3 +1,5 @@
import warnings

import tensorflow as tf

from keras import backend
@@ -124,6 +126,8 @@ def _backend_update_step(self, grads, trainable_variables, learning_rate):
    def _distributed_tf_update_step(
        self, distribution, grads_and_vars, learning_rate
    ):
        grads_and_vars = self._all_reduce_sum_gradients(grads_and_vars)

        def apply_grad_to_update_var(var, grad):
            return self.update_step(grad, var, learning_rate)

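The reduced gradients are then applied per variable through `distribution.extended.update`, visible at the top of the next hunk. For reference, `StrategyExtended.update` runs an update function against each copy of a mirrored variable from cross-replica context; a hedged standalone sketch with a toy variable and increment:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    v = tf.Variable(1.0)

def assign_add(var, amount):
    return var.assign_add(amount)

# `group=False` returns the per-replica results rather than grouping
# them into a single op, matching the optimizer's call.
strategy.extended.update(v, assign_add, args=(0.5,), group=False)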
@@ -132,12 +136,48 @@ def apply_grad_to_update_var(var, grad):
                var, apply_grad_to_update_var, args=(grad,), group=False
            )

    def _all_reduce_sum_gradients(self, grads_and_vars):
        """Returns all-reduced gradients aggregated via summation.

        Args:
            grads_and_vars: List of (gradient, variable) pairs.

        Returns:
            List of (gradient, variable) pairs where gradients have
            been all-reduced.
        """
        replica_context = tf.distribute.get_replica_context()
        if not replica_context:
            return grads_and_vars

        grads_and_vars = list(grads_and_vars)
        filtered_grads_and_vars = filter_empty_gradients(grads_and_vars)
        if filtered_grads_and_vars:
            grads = [pair[0] for pair in filtered_grads_and_vars]
            reduced = replica_context.all_reduce(
                tf.distribute.ReduceOp.SUM, grads
            )
        else:
            reduced = []
        # Copy `reduced`, adding the `None` gradients back in.
        reduced_with_nones = []
        reduced_pos = 0
        for g, v in grads_and_vars:
            if g is None:
                reduced_with_nones.append((None, v))
            else:
                reduced_with_nones.append((reduced[reduced_pos], v))
                reduced_pos += 1
        assert reduced_pos == len(reduced), "Failed to add all gradients"
        return reduced_with_nones

    def _overwrite_model_variables_with_average_value(
        self, trainable_variables
    ):
        """Overwrite model variables with their moving average values.

        This function overwrites variables on each device.

        Args:
            trainable_variables: List of model variables.
@@ -178,3 +218,34 @@ def _clip_by_norm(self, values, axes=None):
        # We need to use a TF-specific op to support the case where
        # `values` is a `tf.IndexedSlices`.
        return tf.clip_by_norm(values, self.clipnorm, axes)


def filter_empty_gradients(grads_and_vars):
    """Filter out `(grad, var)` pairs that have a gradient equal to `None`."""
    grads_and_vars = tuple(grads_and_vars)
    if not grads_and_vars:
        return grads_and_vars

    filtered = []
    vars_with_empty_grads = []
    for grad, var in grads_and_vars:
        if grad is None:
            vars_with_empty_grads.append(var)
        else:
            filtered.append((grad, var))
    filtered = tuple(filtered)

    if not filtered:
        variable = [v.name for _, v in grads_and_vars]
        raise ValueError(
            f"No gradients provided for any variable: {variable}. "
            f"Provided `grads_and_vars` is {grads_and_vars}."
        )
    if vars_with_empty_grads:
        warnings.warn(
            "Gradients do not exist for variables "
            f"{[v.name for v in vars_with_empty_grads]} when minimizing "
            "the loss. If you're using `model.compile()`, did you forget "
            "to provide a `loss` argument?"
        )
    return filtered
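
A hypothetical usage sketch of `filter_empty_gradients`; the stand-in variable objects only need a `.name` attribute for the messages and are not Keras variables:

from types import SimpleNamespace

w1 = SimpleNamespace(name="w1")
w2 = SimpleNamespace(name="w2")

# w2 has no gradient: it is dropped from the result, with a warning
# that names it.
pairs = [(0.5, w1), (None, w2)]
assert filter_empty_gradients(pairs) == ((0.5, w1),)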
