From 3ad1f48c140fa2ac41f51ff41a8fd01174c413ea Mon Sep 17 00:00:00 2001 From: Felipe Date: Mon, 29 Jul 2024 12:15:57 -0300 Subject: [PATCH] . --- ctgan/synthesizers/ctgan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index bbcac8f3..e40b4bdc 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -39,7 +39,7 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb interpolates = alpha * real_data + ((1 - alpha) * fake_data) disc_interpolates = self(interpolates) - self.set_device(device) + torch.cuda.set_device(0) gradients = torch.autograd.grad( outputs=disc_interpolates, inputs=interpolates,