diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 1045f4a5..c1bf9f2b 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -522,7 +522,7 @@ def predict_start_from_noise(self, x_t, t, noise): def predict_noise_from_start(self, x_t, t, x0): return ( - (x0 - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t) / \ + (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) ) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index e11448a9..a636f701 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.25.1' +__version__ = '0.25.2'