diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index bbcac8f3..e40b4bdc 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -39,7 +39,7 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb interpolates = alpha * real_data + ((1 - alpha) * fake_data) disc_interpolates = self(interpolates) - self.set_device(device) + torch.cuda.set_device(0) gradients = torch.autograd.grad( outputs=disc_interpolates, inputs=interpolates,