add gradient checkpointing for all resnet blocks
lucidrains committed Aug 3, 2022
1 parent 451de34 commit be3bb86
Showing 2 changed files with 52 additions and 11 deletions.
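
For context on the technique being wired in: gradient checkpointing discards intermediate activations during the forward pass and recomputes them during the backward pass, trading extra compute for lower memory. A minimal sketch of the underlying torch.utils.checkpoint call, independent of this repository (the toy block and shapes are illustrative):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# a toy block standing in for a resnet block
block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad = True)

out = checkpoint(block, x)   # forward pass runs without storing the block's intermediate activations
out.sum().backward()         # the block is re-run here to recompute them for the gradient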
dalle2_pytorch/dalle2_pytorch.py (61 changes: 51 additions & 10 deletions)
@@ -8,6 +8,7 @@

import torch
import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
from torch import nn, einsum
import torchvision.transforms as T

@@ -108,6 +109,28 @@ def pad_tuple_to_length(t, length, fillvalue = None):
        return t
    return (*t, *((fillvalue,) * remain_length))

+# checkpointing helper function
+
+def make_checkpointable(fn, **kwargs):
+    if isinstance(fn, nn.ModuleList):
+        return [maybe(make_checkpointable)(el, **kwargs) for el in fn]
+
+    condition = kwargs.pop('condition', None)
+
+    if exists(condition) and not condition(fn):
+        return fn
+
+    @wraps(fn)
+    def inner(*args):
+        input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
+
+        if not input_needs_grad:
+            return fn(*args)
+
+        return checkpoint(fn, *args)
+
+    return inner

# for controlling freezing of CLIP

def set_module_requires_grad_(module, requires_grad):
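
A usage sketch for the helper above; the import path and shapes are assumptions for illustration (the function is defined at module level in dalle2_pytorch/dalle2_pytorch.py):

import torch
from torch import nn
from dalle2_pytorch.dalle2_pytorch import make_checkpointable  # assumed import path

block = nn.Linear(32, 32)
ckpt_block = make_checkpointable(block)

x = torch.randn(4, 32, requires_grad = True)
y = ckpt_block(x)            # routed through torch.utils.checkpoint; activations recomputed on backward
y.sum().backward()

x_frozen = torch.randn(4, 32)
y = ckpt_block(x_frozen)     # no tensor argument requires grad, so this falls through to a plain call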
@@ -1698,6 +1721,7 @@ def __init__(
        pixel_shuffle_upsample = True,
        final_conv_kernel_size = 1,
        combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
+        checkpoint_during_training = False,
        **kwargs
    ):
        super().__init__()
@@ -1908,6 +1932,10 @@ def __init__(

        zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it

+        # whether to checkpoint during training
+
+        self.checkpoint_during_training = checkpoint_during_training
+
    # if the current settings for the unet are not correct
    # for cascading DDPM, then reinit the unet with the right settings
    def cast_model_parameters(
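
A construction sketch for the new flag; the surrounding hyperparameters are illustrative (mirroring the README-style configuration) rather than anything prescribed by this commit:

from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    checkpoint_during_training = True   # the new flag; only takes effect while unet.training is True
)

At sampling or evaluation time checkpointing stays off, both because the module is in eval mode and because forward accepts a per-call disable_checkpoint override (added below).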
@@ -1965,7 +1993,8 @@ def forward(
        image_cond_drop_prob = 0.,
        text_cond_drop_prob = 0.,
        blur_sigma = None,
-        blur_kernel_size = None
+        blur_kernel_size = None,
+        disable_checkpoint = False
    ):
        batch_size, device = x.shape[0], x.device

@@ -2087,17 +2116,29 @@ def forward(
        c = self.norm_cond(c)
        mid_c = self.norm_mid_cond(mid_c)

+        # gradient checkpointing
+
+        can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
+        apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
+
+        # make checkpointable modules
+
+        init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
+
+        can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
+        downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
+
        # initial resnet block

-        if exists(self.init_resnet_block):
-            x = self.init_resnet_block(x, t)
+        if exists(init_resnet_block):
+            x = init_resnet_block(x, t)

        # go through the layers of the unet, down and up

        down_hiddens = []
        up_hiddens = []

-        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
+        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
            if exists(pre_downsample):
                x = pre_downsample(x)

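The downs / ups wrapping above only checkpoints modules that satisfy the ResnetBlock predicate; a self-contained sketch of that condition-gated behavior with stand-in module types (the import path is again an assumption):

import torch
from torch import nn
from dalle2_pytorch.dalle2_pytorch import make_checkpointable  # assumed import path

layers = nn.ModuleList([nn.Conv2d(8, 8, 3, padding = 1), nn.GroupNorm(4, 8)])

# only the convolution passes the predicate and gets wrapped; the GroupNorm is returned untouched
wrapped = make_checkpointable(layers, condition = lambda m: isinstance(m, nn.Conv2d))

x = torch.randn(2, 8, 16, 16, requires_grad = True)
for layer in wrapped:   # a ModuleList input comes back as a plain Python list
    x = layer(x)
x.sum().backward()
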
@@ -2113,16 +2154,16 @@ def forward(
            if exists(post_downsample):
                x = post_downsample(x)

-        x = self.mid_block1(x, t, mid_c)
+        x = mid_block1(x, t, mid_c)

-        if exists(self.mid_attn):
-            x = self.mid_attn(x)
+        if exists(mid_attn):
+            x = mid_attn(x)

-        x = self.mid_block2(x, t, mid_c)
+        x = mid_block2(x, t, mid_c)

        connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)

-        for init_block, resnet_blocks, attn, upsample in self.ups:
+        for init_block, resnet_blocks, attn, upsample in ups:
            x = connect_skip(x)
            x = init_block(x, t, c)

@@ -2139,7 +2180,7 @@ def forward(

        x = torch.cat((x, r), dim = 1)

-        x = self.final_resnet_block(x, t)
+        x = final_resnet_block(x, t)

        if exists(lowres_cond_img):
            x = torch.cat((x, lowres_cond_img), dim = 1)
dalle2_pytorch/version.py (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-__version__ = '1.4.6'
+__version__ = '1.5.0'
