Skip to content

Commit

Permalink
fix it
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 24, 2024
1 parent 13d97b0 commit b3e2dd8
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions open_diloco/train_pure_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,15 @@ def train(config: Config):

mask = torch.rand_like(param_offloaded.grad) > 0.95

data_to_all_reduce = param_offloaded.grad * mask
data_to_send = param_offloaded.grad * mask
data_to_send_pre_reduce = data_to_send.clone()

# gloo does not support AVG
data_to_all_reduce = data_to_all_reduce / global_pg.size()
dist.all_reduce(data_to_all_reduce, op=dist.ReduceOp.SUM, group=global_pg)

param_offloaded.grad += data_to_all_reduce
data_to_send = data_to_send / global_pg.size()
dist.all_reduce(data_to_send, op=dist.ReduceOp.SUM, group=global_pg)

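param_offloaded.grad += data_to_send - data_to_send_pre_reduce # removing the local pre-reduce masked values (already part of grad) so the reduced result replaces them instead of being added on top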

outer_optimizer.step()
outer_optimizer.zero_grad()

Expand Down

0 comments on commit b3e2dd8

Please sign in to comment.