From 4ff6d021c99c16f3faa087852d34224109edbfb3 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 12 Apr 2022 08:54:40 -0700 Subject: [PATCH] pin to newer version of CLIP that returns encoded text and images, get some helper functions ready for XCLIP --- dalle2_pytorch/dalle2_pytorch.py | 42 ++++++++++++++++++++++++++++++++ setup.py | 2 +- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 88865e5a..f461d2ee 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -3,6 +3,48 @@ from torch import nn, einsum from einops import rearrange +# use x-clip + +from x_clip import CLIP + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# for controlling freezing of CLIP + +def set_module_requires_grad_(module, requires_grad): + for param in module.parameters(): + param.requires_grad = requires_grad + +def freeze_all_layers_(module): + set_module_requires_grad_(module, False) + +def unfreeze_all_layers_(module): + set_module_requires_grad_(module, True) + +# diffusion prior + +class DiffusionPrior(nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + return x + +# decoder + +class Decoder(nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + return x + +# main class + class DALLE2(nn.Module): def __init__(self): super().__init__() diff --git a/setup.py b/setup.py index 10446bc6..f294d35d 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ install_requires=[ 'einops>=0.4', 'torch>=1.6', - 'x-clip' + 'x-clip>=0.4.1' ], classifiers=[ 'Development Status :: 4 - Beta',