set the default to use set_to_none for clearing gradients in BF16 optimizer. (microsoft#5434)

As discussed in microsoft#5175, set the default to use set_to_none for clearing
gradients in the BF16 optimizer. Additionally, for the case of zeroing gradients
(when they are not set to None), use torch._foreach_zero_ (see the sketch below).
Verified correctness with mega-ds llama 7B training.

FYI @loadams

---------

Co-authored-by: Logan Adams <[email protected]>
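For context, the difference between the two clearing modes can be sketched as follows. This is a minimal, standalone illustration; the parameter list and the set_to_none flag here are invented for the example and are not taken from the optimizer code.

import torch

# Three dummy parameters with populated gradients.
params = [torch.randn(4, requires_grad=True) for _ in range(3)]
for p in params:
    p.grad = torch.ones_like(p)

set_to_none = True  # the new default behavior introduced by this commit

if set_to_none:
    # Drop the gradient tensors entirely; the next backward pass allocates
    # fresh ones, so no stale gradient memory is kept between steps.
    for p in params:
        p.grad = None
else:
    # Keep the gradient tensors (and their memory addresses) and zero them
    # all with one fused call, which graph replay requires.
    torch._foreach_zero_([p.grad for p in params if p.grad is not None])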
2 people authored and dbyoung18 committed Jun 11, 2024
1 parent f113ab0 commit 98d94fb
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions deepspeed/runtime/bf16_optimizer.py
@@ -341,7 +341,7 @@ def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads):
 
         # clear gradients
         if clear_lp_grads:
-            lp.grad._zero()
+            lp.grad.zero_()
 
     @torch.no_grad()
     def _update_hp_grads_func(self, clear_lp_grads=False):
@@ -441,11 +441,20 @@ def clear_hp_grads(self):
             self.fp32_groups_has_gradients[i] = [False] * len(group)
 
     def clear_lp_grads(self):
+
+        # using zero_() fixed memory address for graph replay
+        set_to_none = False if self.graph_harvesting else True
+        zero_grads_list = []
         for group in self.bf16_groups:
             for param in group:
-                if param.grad is not None:
-                    # Using zero_() fixed memory address for graph replay
-                    param.grad.zero_()
+                if set_to_none:
+                    param.grad = None
+                elif param.grad is not None:
+                    if param.grad.grad_fn is not None:
+                        param.grad.detach_()
+                    zero_grads_list.append(param.grad)
+        if not set_to_none and len(zero_grads_list) > 0:
+            torch._foreach_zero_(zero_grads_list)
 
     def state_dict(self):
         state_dict = {}
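The elif branch above also detaches a gradient that carries autograd history before zeroing it. A small, self-contained sketch of why that matters is given below; the tensor p is illustrative, and create_graph=True is just one way a gradient can end up with a grad_fn.

import torch

p = torch.randn(4, requires_grad=True)
(p * p).sum().backward(create_graph=True)  # p.grad (= 2*p) now carries a grad_fn

grads = [p.grad]
for g in grads:
    if g.grad_fn is not None:
        g.detach_()                # same storage, autograd history dropped
torch._foreach_zero_(grads)        # fused in-place zeroing, addresses unchanged

print(p.grad.grad_fn)              # None
print(p.grad)                      # tensor([0., 0., 0., 0.])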
