diff --git a/nuwa_pytorch/train_vqgan_vae.py b/nuwa_pytorch/train_vqgan_vae.py index de71468..066a974 100644 --- a/nuwa_pytorch/train_vqgan_vae.py +++ b/nuwa_pytorch/train_vqgan_vae.py @@ -291,7 +291,9 @@ def train_step(self): # update discriminator if exists(self.vae.discr): + self.discr_optim.zero_grad() discr_loss = 0 + for _ in range(self.grad_accum_every): img = next(self.dl) img = img.to(device) @@ -302,7 +304,6 @@ def train_step(self): (loss / self.grad_accum_every).backward() self.discr_optim.step() - self.discr_optim.zero_grad() # log diff --git a/setup.py b/setup.py index 061165a..eda9ef1 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name = 'nuwa-pytorch', packages = find_packages(exclude=[]), include_package_data = True, - version = '0.7.7', + version = '0.7.8', license='MIT', description = 'NÜWA - Pytorch', long_description_content_type = 'text/markdown',