fix everything and make sure it runs end to end, document everything in readme for public

lucidrains committed Apr 14, 2022
1 parent e5e4152 commit a1a8a78
Showing 4 changed files with 364 additions and 73 deletions.
283 changes: 276 additions & 7 deletions README.md

```bash
$ pip install dalle2-pytorch
```

## CLI Usage (work in progress)

```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```

Once built, images will be saved to the same directory the command was invoked from.

## Training (for deep learning practitioners)

Training DALLE-2 is a three-step process, with the training of CLIP being the most important.

To train CLIP, you can either use the `x-clip` package, or join the LAION discord, where a lot of replication efforts are already underway.

This repository will demonstrate integration with `x-clip` for starters.

```python
import torch
from dalle2_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = True,           # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True,        # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True,                 # whether to do self supervised learning on images
    visual_ssl_type = 'simclr',            # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
    use_mlm = False,                       # use masked language learning (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05,           # weight for text MLM loss
    image_ssl_loss_weight = 0.05           # weight for image self-supervised learning loss
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train

loss = clip(
    text,
    images,
    return_loss = True # needs to be set to True to return contrastive loss
)

loss.backward()

# do the above with as many texts and images as possible in a loop
```
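As a rough sketch of that loop, continuing from the block above — the optimizer choice, learning rate, and step count here are illustrative assumptions, and the mock tensors stand in for a real dataloader. The same three-line pattern (forward for the loss, backward, optimizer step) applies to the decoder and diffusion prior training below.

```python
import torch
from torch.optim import Adam

opt = Adam(clip.parameters(), lr = 3e-4)

for _ in range(100000):
    # swap the mock tensors for real token ids and images from your dataloader

    text = torch.randint(0, 49408, (4, 256)).cuda()
    images = torch.randn(4, 3, 256, 256).cuda()

    loss = clip(text, images, return_loss = True)
    loss.backward()

    opt.step()
    opt.zero_grad()
```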

Then, you will need to train the decoder, which learns to generate images based on the image embedding coming from the trained CLIP above.

```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP

# trained clip from step 1

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# unet for the decoder

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    time_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()

# decoder, which contains the unet and clip

decoder = Decoder(
    net = unet,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

# mock images (get a lot of this)

images = torch.randn(4, 3, 256, 256).cuda()

# feed images into decoder

loss = decoder(images)
loss.backward()

# do the above for many many many many steps
# then it will learn to generate images based on the CLIP image embeddings
```
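Note that the decoder (and the prior, below) needs the *trained* CLIP from step 1, not a freshly initialized one. One way to carry it across training scripts is a plain PyTorch checkpoint — a minimal sketch, with `./clip.pt` as a hypothetical path:

```python
import torch

# at the end of step 1: persist the trained CLIP weights

torch.save(clip.state_dict(), './clip.pt')

# at the start of steps 2 and 3: rebuild an identically configured CLIP, then restore the weights

clip.load_state_dict(torch.load('./clip.pt'))
```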

Finally, the main contribution of the paper: the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step.

```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# setup the prior network, which contains an autoregressive transformer

prior_network = DiffusionPriorNetwork(
    dim = 512,
    num_timesteps = 100,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

# diffusion prior, which contains the CLIP and network (with transformer) above

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed text and images into diffusion prior network

loss = diffusion_prior(text, images)
loss.backward()

# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```

Finally, to generate the DALL-E2 images from text, insert the trained `DiffusionPrior` as well as the `Decoder` (which together wrap `CLIP`, a unet, and a causal transformer).

```python
from dalle2_pytorch import DALLE2

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

# send the text as a string if you want to use the simple tokenizer from DALL-E1
# or you can do it as token ids, if you have your own tokenizer

texts = ['glistening morning dew on a flower petal']
images = dalle2(texts) # (1, 3, 256, 256)
```
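A sketch of the token id path mentioned in the comment above, using mock ids — this assumes the forward pass accepts a `(batch, text_seq_len)` tensor matching the `text_seq_len = 256` the CLIP was configured with:

```python
import torch

# token ids from your own tokenizer

text = torch.randint(0, 49408, (1, 256)).cuda()

images = dalle2(text) # (1, 3, 256, 256)
```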

That's it!

Let's see the whole script below.

```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train

loss = clip(
    text,
    images,
    return_loss = True
)

loss.backward()

# do above for many steps ...

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    num_timesteps = 100,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

loss = diffusion_prior(text, images)
loss.backward()

# do above for many steps ...

# decoder (with unet)

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    time_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()

decoder = Decoder(
    net = unet,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

loss = decoder(images)
loss.backward()

# do above for many steps

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(['cute puppy chasing after a squirrel'])

# save your image
```
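For the final `# save your image` step, `torchvision` is one option — a sketch that assumes the output tensor is in (or is clamped to) the `[0, 1]` range, with an arbitrary filename:

```python
from torchvision.utils import save_image

# images is (1, 3, 256, 256); save the single sample to disk

save_image(images[0].clamp(0., 1.), './squirrel_puppy.png')
```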

Everything in this readme should run without error.

## Training CLI (wip)

<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>

## Todo

- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
- [x] add what was proposed in the paper, where the DDPM objective for the image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
- [ ] augment unet so that it can also be conditioned on text encodings (although in the paper they hinted this didn't make much of a difference)
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)

3 changes: 2 additions & 1 deletion dalle2_pytorch/__init__.py

from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from x_clip import CLIP