Fix activation cpu offloading (#2724)
* add cpu offload

* Update composer/trainer/dist_strategy.py

Co-authored-by: Mihir Patel <[email protected]>

* add version guard and fix cpu offload only

* fix import

* add unit test

* add version check

* inline import in test

* apply eval client fix

* apply eval client fix

* Update tests/trainer/test_fsdp_act_ckpt_offload.py

Co-authored-by: Mihir Patel <[email protected]>

* move test to test_fsdp

---------

Co-authored-by: Mihir Patel <[email protected]>
cli99 and mvpatel2000 authored Nov 22, 2023
1 parent a193115 commit 4dcbc2b
Showing 3 changed files with 96 additions and 16 deletions.
47 changes: 33 additions & 14 deletions composer/trainer/dist_strategy.py
@@ -530,21 +530,40 @@ def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_para

        # Activation Checkpointing
        if activation_checkpointing or activation_cpu_offload:
-            if not activation_checkpointing_reentrant:
-                first_wrap_fn = lambda m: checkpoint_wrapper(m, checkpoint_impl=CheckpointImpl.NO_REENTRANT
-                                                            ) if activation_checkpointing else (lambda module:
-                                                                                                module)
-                second_wrap_fn = (
-                    lambda module: checkpoint_wrapper(
-                        first_wrap_fn(module),  # type: ignore reportGeneralTypeIssues
-                        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
-                        offload_to_cpu=True)) if activation_cpu_offload else first_wrap_fn
+            if version.parse(torch.__version__) > version.parse('2.1.0.dev'):
+                from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
+                if not activation_checkpointing_reentrant:
+                    first_wrap_fn = lambda m: checkpoint_wrapper(m, checkpoint_impl=CheckpointImpl.NO_REENTRANT
+                                                                ) if activation_checkpointing else (lambda module:
+                                                                                                    module)
+                    second_wrap_fn = (
+                        lambda module: offload_wrapper(
+                            first_wrap_fn(module)
+                            if activation_checkpointing else module,  # type: ignore reportGeneralTypeIssues
+                        )) if activation_cpu_offload else first_wrap_fn
+                else:
+                    first_wrap_fn = checkpoint_wrapper if activation_checkpointing else (lambda module: module)
+                    second_wrap_fn = (
+                        lambda module: offload_wrapper(
+                            first_wrap_fn(module)
+                            if activation_checkpointing else module)  # type: ignore reportGeneralTypeIssues
+                    ) if activation_cpu_offload else first_wrap_fn
            else:
-                first_wrap_fn = checkpoint_wrapper if activation_checkpointing else (lambda module: module)
-                second_wrap_fn = (
-                    lambda module: checkpoint_wrapper(
-                        first_wrap_fn(module),  # type: ignore reportGeneralTypeIssues
-                        offload_to_cpu=True)) if activation_cpu_offload else first_wrap_fn
+                if not activation_checkpointing_reentrant:
+                    first_wrap_fn = lambda m: checkpoint_wrapper(m, checkpoint_impl=CheckpointImpl.NO_REENTRANT
+                                                                ) if activation_checkpointing else (lambda module:
+                                                                                                    module)
+                    second_wrap_fn = (
+                        lambda module: checkpoint_wrapper(
+                            first_wrap_fn(module),  # type: ignore reportGeneralTypeIssues
+                            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+                            offload_to_cpu=True)) if activation_cpu_offload else first_wrap_fn
+                else:
+                    first_wrap_fn = checkpoint_wrapper if activation_checkpointing else (lambda module: module)
+                    second_wrap_fn = (
+                        lambda module: checkpoint_wrapper(
+                            first_wrap_fn(module),  # type: ignore reportGeneralTypeIssues
+                            offload_to_cpu=True)) if activation_cpu_offload else first_wrap_fn

            # Choose which modules to activation checkpoint according to the following priority:
            # If module has attribute `module._activation_checkpointing = ...`, always respect it
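
The net effect of the new torch >= 2.1 branch is that CPU offload becomes a separate `offload_wrapper` applied around the (optionally) checkpoint-wrapped module, instead of `offload_to_cpu=True` being passed to `checkpoint_wrapper`. A minimal standalone sketch of the resulting nesting, assuming PyTorch 2.1+ and a plain `Linear` as a stand-in submodule (not part of the commit):

# Minimal sketch (not part of the commit): wrapper nesting on PyTorch >= 2.1
# when both activation checkpointing and CPU offload are enabled.
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
    offload_wrapper,
)

module = torch.nn.Linear(8, 8)

# Non-reentrant checkpointing first, then CPU offload as the outer wrapper
# (the pre-2.1 path expressed the offload via `offload_to_cpu=True` instead).
wrapped = offload_wrapper(checkpoint_wrapper(module, checkpoint_impl=CheckpointImpl.NO_REENTRANT))

print(type(wrapped).__name__)                             # OffloadWrapper
print(type(wrapped._checkpoint_wrapped_module).__name__)  # CheckpointWrapper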
61 changes: 60 additions & 1 deletion tests/trainer/test_fsdp.py
@@ -4,9 +4,10 @@
import pytest
import torch
from packaging import version
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper
from torch.utils.data import DataLoader

-from composer.models import ComposerClassifier
+from composer.models import ComposerClassifier, ComposerModel
from composer.trainer.trainer import Trainer
from composer.utils import dist
from tests.common import (EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel,
@@ -117,3 +118,61 @@ def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limi
)

trainer.fit()


class SimpleMLP(ComposerModel):

def __init__(self, num_features: int = 128, device: str = 'cuda'):
super().__init__()
self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False)

def forward(self, x):
x = self.fc1(x)
        x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x

def loss(self, outputs, batch):
pass


@world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('activation_checkpointing', [True, False])
@pytest.mark.parametrize('activation_cpu_offload', [True, False])
def test_fsdp_act_ckpt_offload(
activation_checkpointing: bool,
activation_cpu_offload: bool,
world_size: int,
):
model = SimpleMLP()

fsdp_config = {
'activation_checkpointing': activation_checkpointing,
'activation_checkpointing_reentrant': False,
'activation_cpu_offload': activation_cpu_offload,
}

model.fc1._activation_checkpointing = True

trainer = Trainer(
model=model,
device='gpu',
fsdp_config=fsdp_config,
)

assert trainer.state.fsdp_enabled
    if version.parse(torch.__version__) > version.parse('2.1.0.dev'):
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import OffloadWrapper

        if activation_checkpointing and activation_cpu_offload:
            assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module, OffloadWrapper)
            assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module._checkpoint_wrapped_module,
                              CheckpointWrapper)
        elif activation_checkpointing:
            assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module, CheckpointWrapper)
        elif activation_cpu_offload:
            assert isinstance(trainer.state.model.fc1._fsdp_wrapped_module, OffloadWrapper)
        else:
            assert not isinstance(trainer.state.model.fc1._fsdp_wrapped_module, CheckpointWrapper)
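
For the "fix cpu offload only" part of the commit message, the `elif activation_cpu_offload:` branch above asserts that a module with checkpointing disabled ends up wrapped directly in `OffloadWrapper`. A minimal sketch of that case outside FSDP and the Trainer, assuming PyTorch 2.1+ (the `Linear` stand-in is illustrative, not from the commit):

# Sketch of the offload-only wrapping checked by the test above; assumes
# PyTorch >= 2.1. No Trainer or FSDP needed just to observe the wrapper type.
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
    OffloadWrapper,
    offload_wrapper,
)

fc = torch.nn.Linear(128, 128, bias=False)
wrapped = offload_wrapper(fc)  # what dist_strategy.py applies when only offload is enabled

assert isinstance(wrapped, OffloadWrapper)
assert not isinstance(wrapped, CheckpointWrapper)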
4 changes: 3 additions & 1 deletion tests/utils/eval_client/test_local_eval_client.py
@@ -28,9 +28,11 @@
],
)
@world_size(1, 2)
-def test_local_invoke(code: str, result: str, language: str, world_size: int):
+def test_local_invoke(code: str, result: str, language: str, world_size: int, tmp_path: str):
"""Test invocation function for LocalEvalClient with code that succeeds, fails compilation, times out, and is incorrect in C, C++, Python, JS.
"""
+    import os
+    os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
eval_client = LocalEvalClient()
input = '(1,)' if language == 'python' else '1'
assert eval_client.invoke([[[{
