
Commit

add smth special maybe ?
samsja committed Sep 24, 2024
1 parent 71a9f89 commit 13d97b0
Showing 2 changed files with 9 additions and 6 deletions.
open_diloco/simulate_multi_node.sh — file mode changed 100644 → 100755 (no content changes).
15 changes: 9 additions & 6 deletions open_diloco/train_pure_fsdp.py
@@ -240,13 +240,16 @@ def train(config: Config):
         for param_offloaded, param in zip(cpu_model, model.parameters()):
             # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices
             param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
 
+            mask = torch.rand_like(param_offloaded.grad) > 0.95
+
+            data_to_all_reduce = param_offloaded.grad * mask
+
-            if param_offloaded.grad.device == torch.device("cpu"):
-                # gloo does not support AVG
-                param_offloaded.grad = param_offloaded.grad / global_pg.size()
-                dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=global_pg)
-            else:
-                dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg)
+            # gloo does not support AVG
+            data_to_all_reduce = data_to_all_reduce / global_pg.size()
+            dist.all_reduce(data_to_all_reduce, op=dist.ReduceOp.SUM, group=global_pg)
+
+            param_offloaded.grad += data_to_all_reduce
 
         outer_optimizer.step()
         outer_optimizer.zero_grad()
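In short, the commit replaces the dense AVG all-reduce of the outer pseudo-gradient with a sparse variant: each outer step only a random ~5% of the pseudo-gradient entries are communicated, and the averaged sparse contribution is added back onto the local pseudo-gradient before the outer optimizer step. Below is a minimal, self-contained sketch of that masked all-reduce, assuming torch.distributed with a gloo backend. The function name masked_pseudo_grad_allreduce, the keep_prob parameter, and the single-process setup are illustrative, not part of the commit.

import os

import torch
import torch.distributed as dist


def masked_pseudo_grad_allreduce(cpu_params, gpu_params, global_pg, keep_prob=0.05):
    for param_offloaded, param in zip(cpu_params, gpu_params):
        # Outer (pseudo-)gradient: offloaded weights minus the current local weights.
        param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)

        # Keep roughly `keep_prob` of the entries (the commit uses rand > 0.95, i.e. ~5%).
        mask = torch.rand_like(param_offloaded.grad) > (1.0 - keep_prob)
        data_to_all_reduce = param_offloaded.grad * mask

        # gloo has no ReduceOp.AVG, so pre-divide by the group size and SUM instead.
        data_to_all_reduce = data_to_all_reduce / global_pg.size()
        dist.all_reduce(data_to_all_reduce, op=dist.ReduceOp.SUM, group=global_pg)

        # Accumulate the averaged sparse update on top of the local pseudo-gradient.
        param_offloaded.grad += data_to_all_reduce


if __name__ == "__main__":
    # Single-process gloo group, purely to exercise the function locally; the real
    # code runs this across workers with a process group spanning all nodes.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    cpu_params = [torch.zeros(4, 4)]
    gpu_params = [torch.randn(4, 4)]  # stands in for model.parameters()
    masked_pseudo_grad_allreduce(cpu_params, gpu_params, dist.group.WORLD)
    print(cpu_params[0].grad)

    dist.destroy_process_group()

One design note, hedged: since torch.rand_like draws from each rank's own RNG, the masks differ across workers unless seeds are synchronized, so each rank contributes a different random subset of coordinates to the sum; dividing by the group size before the SUM keeps the result equivalent to an AVG reduce on backends like gloo that do not support it.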
