From 27a33e1b2016ac384174054ba54719408f9883f5 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 20 Apr 2022 10:46:13 -0700
Subject: [PATCH] complete contextmanager method for keeping only one unet in
 GPU during training or inference

---
 README.md                        |  2 +-
 dalle2_pytorch/dalle2_pytorch.py | 23 ++++++++++++++++++++---
 setup.py                         |  2 +-
 3 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index afae8af6..2ffb0345 100644
--- a/README.md
+++ b/README.md
@@ -411,8 +411,8 @@ Offer training wrappers
 - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
 - [x] add efficient attention in unet
 - [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
+- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
-- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] train on a toy task, offer in colab
 
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 00a3032c..4bb41cf0 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -2,6 +2,7 @@
 from tqdm import tqdm
 from inspect import isfunction
 from functools import partial
+from contextlib import contextmanager
 
 import torch
 import torch.nn.functional as F
@@ -1141,6 +1142,20 @@ def __init__(
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+    @contextmanager
+    def one_unet_in_gpu(self, unet_number):
+        assert 0 < unet_number <= len(self.unets)
+        index = unet_number - 1
+        self.cuda()
+        self.unets.cpu()
+
+        unet = self.unets[index]
+        unet.cuda()
+
+        yield
+
+        self.unets.cpu()
+
     def get_text_encodings(self, text):
         text_encodings = self.clip.text_transformer(text)
         return text_encodings[:, 1:]
@@ -1245,9 +1260,11 @@ def sample(self, image_embed, text = None, cond_scale = 1.):
         text_encodings = self.get_text_encodings(text) if exists(text) else None
 
         img = None
-        for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
-            shape = (batch_size, channels, image_size, image_size)
-            img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
+
+        for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
+            with self.one_unet_in_gpu(ind + 1):
+                shape = (batch_size, channels, image_size, image_size)
+                img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
 
         return img
 
diff --git a/setup.py b/setup.py
index 41e80b0e..e6d66c96 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.27',
+  version = '0.0.28',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
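
For context, the offloading pattern that the new one_unet_in_gpu context manager implements can be sketched in isolation. The Cascade class below is a hypothetical stand-in for the patch's Decoder (it is not part of this patch), and the try/finally around the yield is an addition of this sketch so the unets are offloaded back to CPU even if the block raises; the patched method performs its moves unconditionally in sequence.

    from contextlib import contextmanager

    import torch
    from torch import nn

    class Cascade(nn.Module):
        # hypothetical stand-in for Decoder: a cascade of unets held in a ModuleList
        def __init__(self, *unets):
            super().__init__()
            self.unets = nn.ModuleList(unets)

        @contextmanager
        def one_unet_in_gpu(self, unet_number):
            # move every unet to CPU, then bring only the requested one onto the GPU,
            # mirroring the method added to Decoder in this patch
            assert 0 < unet_number <= len(self.unets)
            index = unet_number - 1
            self.unets.cpu()

            unet = self.unets[index]
            unet.cuda()

            try:
                yield unet
            finally:
                # offload again on exit, even if the block raised
                self.unets.cpu()

Usage mirrors the new sampling loop: entering the block for unet ind + 1 leaves only that unet resident on the GPU, so peak GPU memory scales with a single unet rather than the whole cascade.

    cascade = Cascade(nn.Linear(8, 8), nn.Linear(8, 8))

    if torch.cuda.is_available():
        with cascade.one_unet_in_gpu(1) as unet:
            out = unet(torch.randn(1, 8, device = 'cuda'))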