diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py index c1e136b5..dd781057 100644 --- a/ctgan/data_transformer.py +++ b/ctgan/data_transformer.py @@ -189,7 +189,8 @@ def transform(self, raw_data): def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st): gm = column_transform_info.transform - data = pd.DataFrame(column_data[:, :2], columns=list(gm.get_output_sdtypes())) + data = pd.DataFrame( + column_data[:, :2], columns=list(gm.get_output_sdtypes())).astype(float) data[data.columns[1]] = np.argmax(column_data[:, 1:], axis=1) if sigmas is not None: selected_normalized_value = np.random.normal(data.iloc[:, 0], sigmas[st])