From 02459b125e62afc6d126466843d02c022b01a83a Mon Sep 17 00:00:00 2001 From: Afonja Tejumade Date: Fri, 13 Oct 2023 18:12:47 +0200 Subject: [PATCH] returns the value of the tensors as a standard Python number I assume that we are interested in the standard Python number, and not the Tensor value. --- ctgan/synthesizers/ctgan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 82dd3aba..e450c5c7 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -422,8 +422,8 @@ def fit(self, train_data, discrete_columns=(), epochs=None): loss_g.backward() optimizerG.step() - generator_loss = loss_g.detach().cpu() - discriminator_loss = loss_d.detach().cpu() + generator_loss = loss_g.detach().cpu().item() + discriminator_loss = loss_d.detach().cpu().item() epoch_loss_df = pd.DataFrame({ 'Epoch': [i],