Skip to content

Commit

Permalink
pin to newer version of CLIP that returns encoded text and images, ge…
Browse files Browse the repository at this point in the history
…t some helper functions ready for XCLIP
  • Loading branch information
lucidrains committed Apr 12, 2022
1 parent 0070547 commit 4ff6d02
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
42 changes: 42 additions & 0 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,48 @@
from torch import nn, einsum
from einops import rearrange

# use x-clip

from x_clip import CLIP

# helper functions

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

# for controlling freezing of CLIP

def set_module_requires_grad_(module, requires_grad):
for param in module.parameters():
param.requires_grad = requires_grad

def freeze_all_layers_(module):
set_module_requires_grad_(module, False)

def unfreeze_all_layers_(module):
set_module_requires_grad_(module, True)

# diffusion prior

class DiffusionPrior(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x

# decoder

class Decoder(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x

# main class

class DALLE2(nn.Module):
def __init__(self):
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
install_requires=[
'einops>=0.4',
'torch>=1.6',
'x-clip'
'x-clip>=0.4.1'
],
classifiers=[
'Development Status :: 4 - Beta',
Expand Down

0 comments on commit 4ff6d02

Please sign in to comment.