diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 427e3e1b..388d8208 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -8,6 +8,7 @@ import torch
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 from torch import nn, einsum
 import torchvision.transforms as T
@@ -108,6 +109,28 @@ def pad_tuple_to_length(t, length, fillvalue = None):
         return t
     return (*t, *((fillvalue,) * remain_length))
 
+# checkpointing helper function
+
+def make_checkpointable(fn, **kwargs):
+    if isinstance(fn, nn.ModuleList):
+        return [maybe(make_checkpointable)(el, **kwargs) for el in fn]
+
+    condition = kwargs.pop('condition', None)
+
+    if exists(condition) and not condition(fn):
+        return fn
+
+    @wraps(fn)
+    def inner(*args):
+        input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
+
+        if not input_needs_grad:
+            return fn(*args)
+
+        return checkpoint(fn, *args)
+
+    return inner
+
 # for controlling freezing of CLIP
 
 def set_module_requires_grad_(module, requires_grad):
@@ -1698,6 +1721,7 @@ def __init__(
         pixel_shuffle_upsample = True,
         final_conv_kernel_size = 1,
         combine_upsample_fmaps = False,  # whether to combine the outputs of all upsample blocks, as in unet squared paper
+        checkpoint_during_training = False,
         **kwargs
     ):
         super().__init__()
@@ -1908,6 +1932,10 @@ def __init__(
 
         zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
 
+        # whether to checkpoint during training
+
+        self.checkpoint_during_training = checkpoint_during_training
+
     # if the current settings for the unet are not correct
     # for cascading DDPM, then reinit the unet with the right settings
     def cast_model_parameters(
@@ -1965,7 +1993,8 @@ def forward(
         image_cond_drop_prob = 0.,
         text_cond_drop_prob = 0.,
         blur_sigma = None,
-        blur_kernel_size = None
+        blur_kernel_size = None,
+        disable_checkpoint = False
     ):
         batch_size, device = x.shape[0], x.device
 
@@ -2087,17 +2116,29 @@ def forward(
         c = self.norm_cond(c)
         mid_c = self.norm_mid_cond(mid_c)
 
+        # gradient checkpointing
+
+        can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
+        apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
+
+        # make checkpointable modules
+
+        init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
+
+        can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
+        downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
+
         # initial resnet block
 
-        if exists(self.init_resnet_block):
-            x = self.init_resnet_block(x, t)
+        if exists(init_resnet_block):
+            x = init_resnet_block(x, t)
 
         # go through the layers of the unet, down and up
 
         down_hiddens = []
         up_hiddens = []
 
-        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
+        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
 
             if exists(pre_downsample):
                 x = pre_downsample(x)
@@ -2113,16 +2154,16 @@ def forward(
             if exists(post_downsample):
                 x = post_downsample(x)
 
-        x = self.mid_block1(x, t, mid_c)
+        x = mid_block1(x, t, mid_c)
 
-        if exists(self.mid_attn):
-            x = self.mid_attn(x)
+        if exists(mid_attn):
+            x = mid_attn(x)
 
-        x = self.mid_block2(x, t, mid_c)
+        x = mid_block2(x, t, mid_c)
 
         connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
 
-        for init_block, resnet_blocks, attn, upsample in self.ups:
+        for init_block, resnet_blocks, attn, upsample in ups:
 
             x = connect_skip(x)
             x = init_block(x, t, c)
@@ -2139,7 +2180,7 @@ def forward(
 
         x = torch.cat((x, r), dim = 1)
 
-        x = self.final_resnet_block(x, t)
+        x = final_resnet_block(x, t)
 
         if exists(lowres_cond_img):
             x = torch.cat((x, lowres_cond_img), dim = 1)
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index adf1ed52..77f1c8e6 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.4.6'
+__version__ = '1.5.0'
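
A quick sketch of what the new make_checkpointable helper does, assuming it is imported directly from dalle2_pytorch.dalle2_pytorch (it is defined at module level above; whether it is re-exported at package level is not shown in this diff). Calls whose tensor arguments require grad are routed through torch.utils.checkpoint, so intermediate activations are recomputed on the backward pass instead of stored; calls with no-grad inputs fall through to a plain invocation, and the condition callable can veto wrapping entirely:

    import torch
    from torch import nn
    from dalle2_pytorch.dalle2_pytorch import make_checkpointable

    block = nn.Sequential(nn.Linear(16, 16), nn.GELU())
    wrapped = make_checkpointable(block)

    x = torch.randn(2, 16, requires_grad = True)
    wrapped(x).sum().backward()          # routed through torch.utils.checkpoint

    y = torch.randn(2, 16)               # no grad required -> plain call, nothing checkpointed
    out = wrapped(y)

    # condition restricts which modules get wrapped, mirroring how forward()
    # only checkpoints the ResnetBlock members of self.downs / self.ups
    gelu = nn.GELU()
    same = make_checkpointable(gelu, condition = lambda m: isinstance(m, nn.Linear))
    assert same is gelu                  # condition rejected it, module returned unwrapped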
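And how the feature surfaces on the Unet; the hyperparameters below are hypothetical, only checkpoint_during_training (constructor) and disable_checkpoint (forward) come from this diff. Since can_checkpoint also requires self.training, eval-mode sampling never pays the recompute cost:

    from dalle2_pytorch import Unet

    unet = Unet(
        dim = 128,
        image_embed_dim = 512,
        cond_dim = 128,
        dim_mults = (1, 2, 4, 8),
        checkpoint_during_training = True    # trade recompute for activation memory
    )

    # a training step can still opt out for a single call, e.g. (names hypothetical):
    # pred = unet(x, time, image_embed = image_embed, disable_checkpoint = True)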