diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 82dd3aba..e450c5c7 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -422,8 +422,8 @@ def fit(self, train_data, discrete_columns=(), epochs=None): loss_g.backward() optimizerG.step() - generator_loss = loss_g.detach().cpu() - discriminator_loss = loss_d.detach().cpu() + generator_loss = loss_g.detach().cpu().item() + discriminator_loss = loss_d.detach().cpu().item() epoch_loss_df = pd.DataFrame({ 'Epoch': [i],