Skip to content

Commit

Permalink
Move contraint reverese_transform order
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Jul 21, 2023
1 parent 64c5d9a commit 5490c6a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
10 changes: 7 additions & 3 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,9 +710,6 @@ def reverse_transform(self, data, reset_keys=False):
except rdt.errors.NotFittedError:
LOGGER.info(f'HyperTransformer has not been fitted for table {self.table_name}')

for constraint in reversed(self._constraints_to_reverse):
reversed_data = constraint.reverse_transform(reversed_data)

num_rows = len(reversed_data)
sampled_columns = list(reversed_data.columns)
missing_columns = [
Expand All @@ -731,6 +728,13 @@ def reverse_transform(self, data, reset_keys=False):
generated_keys = self.generate_keys(num_rows, reset_keys)
sampled_columns.extend(self._keys)

for constraint in reversed(self._constraints_to_reverse):
reversed_data = constraint.reverse_transform(reversed_data)

# Add new columns generated by the constraint
new_columns = list(set(reversed_data.columns) - set(sampled_columns))
sampled_columns.extend(new_columns)

# Sort the sampled columns in the order of the metadata
# In multitable there may be missing columns in the sample such as foreign keys
# And alternate keys. Thats the reason of ensuring that the metadata column is within
Expand Down
17 changes: 17 additions & 0 deletions sdv/data_processing/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from sdv.data_processing.data_processor import DataProcessor
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.copulas import GaussianCopulaSynthesizer
import pandas as pd

data = pd.DataFrame({
'low': [1, 2, 3],
})
metadata = SingleTableMetadata()
metadata.add_column('low', sdtype='numerical')
metadata.update_column('low', sdtype='job', pii=True)

dp = DataProcessor(metadata)
dp.fit(data)
transformed = dp.transform(data)
reverse_transformed = dp.reverse_transform(transformed)
print(reverse_transformed)

0 comments on commit 5490c6a

Please sign in to comment.