Commit 27a33e1

complete contextmanager method for keeping only one unet in GPU during training or inference

lucidrains committed Apr 20, 2022
1 parent 6f941a2 commit 27a33e1
Showing 3 changed files with 22 additions and 5 deletions.

README.md (2 changes: 1 addition & 1 deletion)
```diff
@@ -411,8 +411,8 @@ Offer training wrappers
 - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
 - [x] add efficient attention in unet
 - [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
+- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
-- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] train on a toy task, offer in colab
```

dalle2_pytorch/dalle2_pytorch.py (23 changes: 20 additions & 3 deletions)
```diff
@@ -2,6 +2,7 @@
 from tqdm import tqdm
 from inspect import isfunction
 from functools import partial
+from contextlib import contextmanager
 
 import torch
 import torch.nn.functional as F
```
```diff
@@ -1141,6 +1142,20 @@ def __init__(
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+    @contextmanager
+    def one_unet_in_gpu(self, unet_number):
+        assert 0 < unet_number <= len(self.unets)
+        index = unet_number - 1
+        self.cuda()
+        self.unets.cpu()
+
+        unet = self.unets[index]
+        unet.cuda()
+
+        yield
+
+        self.unets.cpu()
+
     def get_text_encodings(self, text):
         text_encodings = self.clip.text_transformer(text)
         return text_encodings[:, 1:]
```
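
The new method is built on `contextlib.contextmanager`: everything before the `yield` runs when the `with` block is entered (move the decoder to the GPU, evict all unets, then promote only the requested one), and everything after the `yield` runs on exit (return every unet to the CPU). A standalone sketch of the same pattern, using the illustrative names `one_module_in_gpu` and `modules` (not part of the library):

```python
from contextlib import contextmanager

from torch import nn

@contextmanager
def one_module_in_gpu(modules: nn.ModuleList, index: int):
    # setup: runs when the `with` block is entered
    modules.cpu()            # evict every module in the list from GPU memory
    modules[index].cuda()    # promote only the one that is needed
    try:
        yield modules[index]
    finally:
        # teardown: runs when the block exits, even on an exception
        modules.cpu()
```

Note that the commit's version runs its teardown only after a clean `yield`; the `try`/`finally` above is a defensive variation that also restores devices if the body of the `with` block raises.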
```diff
@@ -1245,9 +1260,11 @@ def sample(self, image_embed, text = None, cond_scale = 1.):
         text_encodings = self.get_text_encodings(text) if exists(text) else None
 
         img = None
-        for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
-            shape = (batch_size, channels, image_size, image_size)
-            img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
+
+        for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
+            with self.one_unet_in_gpu(ind + 1):
+                shape = (batch_size, channels, image_size, image_size)
+                img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
 
         return img
```

setup.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -10,7 +10,7 @@
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.27',
+  version = '0.0.28',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
```
