local_sgd: initial version of fault tolerant LocalSGD
d4l3k committed Dec 18, 2024
1 parent a484e4f commit 4bd104e
Showing 8 changed files with 316 additions and 39 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -17,6 +17,7 @@ the entire training job.
manager
optim
ddp
local_sgd
data
checkpointing
parameter_server
4 changes: 4 additions & 0 deletions docs/source/local_sgd.rst
@@ -0,0 +1,4 @@
.. automodule:: torchft.local_sgd
:members:
:undoc-members:
:show-inheritance:
4 changes: 2 additions & 2 deletions torchft/ddp.py
@@ -68,7 +68,7 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> N
def _comm_hook(
state: "Manager", bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
- return state.allreduce_grad(bucket.buffer())
+ return state.allreduce(bucket.buffer())


class PureDistributedDataParallel(nn.Module):
@@ -88,7 +88,7 @@ def __init__(self, manager: "Manager", module: nn.Module) -> None:

def post_grad_hook(p: torch.Tensor) -> None:
if p.grad is not None:
- manager.allreduce_grad(p.grad)
+ manager.allreduce(p.grad)

for p in module.parameters():
p.register_post_accumulate_grad_hook(post_grad_hook)
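The ddp.py change is part of renaming the Manager's allreduce_grad method to the more general allreduce. For reference, a DDP communication hook takes a state object plus a gradient bucket and returns a future that resolves to the reduced tensor. Below is a short illustrative sketch of registering such a hook on a vanilla DDP module with a torchft Manager as the state; the torchft wrapper shown in this diff registers its own hook, so this is only for context and the `ddp_model`/`manager` names are assumptions:

import torch
import torch.distributed as dist


def _comm_hook(
    state: "Manager", bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    # Delegate the gradient reduction to the fault tolerant manager; the
    # returned future resolves to the reduced bucket contents.
    return state.allreduce(bucket.buffer())


# Illustrative registration on an existing DDP module (assumes `ddp_model`
# and `manager` exist):
# ddp_model.register_comm_hook(state=manager, hook=_comm_hook)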
6 changes: 3 additions & 3 deletions torchft/ddp_test.py
@@ -32,14 +32,14 @@ def test_pure_ddp(self) -> None:
for p in m.parameters():
self.assertIsNotNone(p.grad)

- self.assertEqual(manager.allreduce_grad.call_count, len(list(m.parameters())))
+ self.assertEqual(manager.allreduce.call_count, len(list(m.parameters())))

def test_ddp(self) -> None:
manager = create_autospec(Manager)

call_count = 0

- def allreduce_grad(tensor: torch.Tensor) -> Future[torch.Tensor]:
+ def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
nonlocal call_count

call_count += 1
@@ -48,7 +48,7 @@ def allreduce_grad(tensor: torch.Tensor) -> Future[torch.Tensor]:
fut.set_result(tensor)
return fut

- manager.allreduce_grad = allreduce_grad
+ manager.allreduce = allreduce

m = nn.Linear(3, 4)
m = DistributedDataParallel(manager, m)
177 changes: 177 additions & 0 deletions torchft/local_sgd.py
@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
LocalSGD
=========
This module implements a fault tolerant version of LocalSGD and related methods.
"""

from typing import Any, Dict, List, Mapping, Optional

import torch
from torch import nn, optim

from torchft.manager import Manager


class LocalSGD(nn.Module):
"""
LocalSGD is a model wrapper similar to DistributedDataParallel that
implements the algorithm described in https://arxiv.org/pdf/1805.09767.
This will synchronize the model parameters periodically in a fault tolerant
way using a torchft Manager. The allreduce on the parameters will happen
every ``sync_every`` steps after the optimizer.step call.

To make this safe and fault tolerant, a backup copy of the weights is
required. By default these are stored in CPU memory. If any error occurs
during the LocalSGD step, the step will be discarded and the model
parameters will reset back to the last time LocalSGD synchronized.
The backup weights could be eliminated by relaxing the guarantee of exactly
``sync_every`` steps but that would diverge from the LocalSGD algorithm.
DiLoCo also needs this backup copy to compute the delta.

The torchft quorum is computed at the beginning of ``sync_every`` steps. If
any error occurs, or a worker fails between syncs, the ``sync_every`` steps will
be discarded and a new quorum will be computed on the next step.

If running in async mode, on a joining worker the first ``sync_every`` steps
will be discarded as the model will be recovering during that period. When
using sync mode, the checkpoint will be restored prior to the first step.
TODO: add a way via Manager to detect workers failing early for shrink only
TODO: add DiLoCo support
"""

def __init__(
self,
manager: Manager,
model: nn.Module,
optimizer: optim.Optimizer,
sync_every: int,
backup_device: Optional[torch.device] = None,
) -> None:
"""
Args:
manager: The manager to use.
model: The model to wrap.
optimizer: The optimizer whose post-step hook triggers the periodic sync.
sync_every: How often (in optimizer steps) to sync the model weights.
backup_device: The device to store the backup of the model parameters on. (default cpu)
"""
super().__init__()

self._manager = manager
self._model = model
self._local_step = 0
self._started_step = False
self._sync_every = sync_every
assert sync_every >= 1, "sync_every must be greater than or equal to 1"

device = backup_device or torch.device("cpu")

self._backup_parameters: Dict[str, torch.Tensor] = {}

for name, p in self._model.named_parameters():
t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
if t.device == torch.device("cpu"):
t = t.pin_memory()
self._backup_parameters[name] = t

# Need to copy the parameters to the host to be safe if we are on the first step.
self._save_parameters()

optimizer.register_step_post_hook(self._step_post_hook)

def _save_parameters(self) -> None:
# TODO: consider running copy on a separate stream
for name, p in self._model.named_parameters():
self._backup_parameters[name].copy_(p.data, non_blocking=True)

def _restore_parameters(self) -> None:
# TODO: consider running copy on a separate stream
for name, p in self._model.named_parameters():
p.data.copy_(self._backup_parameters[name], non_blocking=True)

# pyre-fixme[14]: support state_dict args
def state_dict(self) -> Dict[str, object]:
"""
state_dict returns the state_dict from the last time LocalSGD
synchronized and not the current weights.
"""
state_dict = self._model.state_dict()
for name, p in self._backup_parameters.items():
assert name in state_dict
state_dict[name] = p
return state_dict

def load_state_dict(
self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
) -> None:
"""
Loads the state dict to the model and the backup parameters.
This must be called while the model weights aren't being modified to
avoid corrupting the backup weights.
"""
self._model.load_state_dict(state_dict, strict=strict, assign=assign)
self._save_parameters()

def forward(self, *args: object, **kwargs: object) -> object:
"""
Run the model forward pass.
This must be called before the optimizer step.
On the first local step this will start a new quorum via the manager.
"""
if self._local_step == 0:
self._manager.start_quorum()

self._started_step = True

return self._model.forward(*args, **kwargs)

def _step_post_hook(
self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
) -> None:
"""
This hook is registered on the optimizer and is called after the optimizer step.
This will call the allreduce on the model weights every sync_every steps.
If any errors occur, it will restore the weights from the previous sync.
``forward`` must be called before this function.
"""
assert self._started_step, "forward must be called before step"
self._started_step = False

self._local_step += 1

if self._local_step >= self._sync_every:
self._local_step = 0
self._average()

if self._manager.should_commit():
# save the parameters so we can restore from them later if necessary.
self._save_parameters()
else:
# commit failed, restore from the backup parameters
self._restore_parameters()

def _average(self) -> None:
# TODO: do we need to broadcast buffers like DDP does?

works = []

for p in self._model.parameters():
# TODO: bucketize parameters
works.append(self._manager.allreduce(p))

for work in works:
work.wait()
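Taken together, the wrapper expects the usual forward / backward / optimizer.step loop, with the averaging handled by the registered post-step hook. A minimal usage sketch under assumed names (the manager construction, the dataloader, and the model/optimizer choices are placeholders, not part of this commit):

import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD
from torchft.manager import Manager

manager: Manager = ...  # assume a configured torchft Manager; construction elided

model = nn.Linear(3, 4)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Average parameters across replicas every 32 optimizer steps.
local_model = LocalSGD(manager, model, optimizer, sync_every=32)

for batch, target in dataloader:  # dataloader is assumed to exist
    optimizer.zero_grad()
    out = local_model(batch)  # the first step of each period starts the quorum
    loss = nn.functional.mse_loss(out, target)
    loss.backward()
    optimizer.step()  # post-step hook runs the allreduce every sync_every steps

# state_dict() returns the weights from the last successful sync, which is
# what should be checkpointed.
checkpoint = local_model.state_dict()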
96 changes: 96 additions & 0 deletions torchft/local_sgd_test.py
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict
from unittest import TestCase
from unittest.mock import create_autospec

import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD
from torchft.manager import Manager


class SimpleModel(nn.Module):
def __init__(self) -> None:
super().__init__()

self.model = nn.Sequential(
nn.Linear(3, 4),
nn.ReLU(),
nn.Linear(4, 5),
nn.Sigmoid(),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)


def _params_dict(m: torch.nn.Module) -> Dict[str, torch.Tensor]:
return {name: p.data for name, p in m.named_parameters()}


def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return {name: value.clone().detach() for name, value in state_dict.items()}


class LocalSGDTest(TestCase):
def test_local_sgd_healthy(self) -> None:
base_m = SimpleModel()
optimizer = optim.SGD(base_m.parameters())
manager = create_autospec(Manager)

m = LocalSGD(manager, base_m, optimizer, sync_every=2)
self.assertEqual(m._local_step, 0)

torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))

inp = torch.rand(2, 3)

loss = m(inp).mean()
loss.backward()
optimizer.step()

self.assertEqual(m._local_step, 1)
self.assertEqual(manager.start_quorum.call_count, 1)

manager.should_commit.return_value = True

loss = m(inp).mean()
loss.backward()
optimizer.step()

self.assertEqual(m._local_step, 0)

torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
self.assertEqual(manager.should_commit.call_count, 1)
self.assertEqual(manager.allreduce.call_count, 4)

def test_local_sgd_recovery(self) -> None:
base_m = SimpleModel()
optimizer = optim.SGD(base_m.parameters())
manager = create_autospec(Manager)

m = LocalSGD(manager, base_m, optimizer, sync_every=2)

torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
og_state_dict = _copy_state_dict(base_m.state_dict())

inp = torch.rand(2, 3)

loss = m(inp).mean()
loss.backward()
optimizer.step()

self.assertEqual(m._local_step, 1)

state_dict = m.state_dict()
torch.testing.assert_close(state_dict, m._backup_parameters)
torch.testing.assert_close(state_dict, og_state_dict)

m.load_state_dict(state_dict)
torch.testing.assert_close(_params_dict(base_m), state_dict)
torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
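The recovery test above covers the state_dict round trip but does not exercise the restore-on-failed-commit path directly. A hedged sketch of what such a test could look like (not part of this commit), reusing SimpleModel and the helpers defined above:

def test_local_sgd_commit_failure(self) -> None:
    base_m = SimpleModel()
    optimizer = optim.SGD(base_m.parameters())
    manager = create_autospec(Manager)
    # Simulate a failed commit so the wrapper rolls back to the backup weights.
    manager.should_commit.return_value = False

    m = LocalSGD(manager, base_m, optimizer, sync_every=1)
    backup_before = _copy_state_dict(m._backup_parameters)

    inp = torch.rand(2, 3)
    m(inp).mean().backward()
    optimizer.step()  # sync_every=1, so this triggers allreduce + should_commit

    # The step is discarded and the parameters roll back to the backup copy.
    torch.testing.assert_close(_params_dict(base_m), backup_before)
    self.assertEqual(manager.should_commit.call_count, 1)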