Skip to content

Commit

Permalink
Remove _data attribute from DataSampler (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho authored Jan 24, 2024
1 parent 39f896e commit 39cfbbe
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
29 changes: 15 additions & 14 deletions ctgan/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class DataSampler(object):
"""DataSampler samples the conditional vector and corresponding data for CTGAN."""

def __init__(self, data, output_info, log_frequency):
self._data = data
self._data_length = len(data)

def is_discrete_column(column_info):
return (len(column_info) == 1
Expand Down Expand Up @@ -115,33 +115,34 @@ def sample_original_condvec(self, batch):
if self._n_discrete_columns == 0:
return None

category_freq = self._discrete_column_category_prob.flatten()
category_freq = category_freq[category_freq != 0]
category_freq = category_freq / np.sum(category_freq)
col_idxs = np.random.choice(np.arange(len(category_freq)), batch, p=category_freq)
cond = np.zeros((batch, self._n_categories), dtype='float32')

for i in range(batch):
row_idx = np.random.randint(0, len(self._data))
col_idx = np.random.randint(0, self._n_discrete_columns)
matrix_st = self._discrete_column_matrix_st[col_idx]
matrix_ed = matrix_st + self._discrete_column_n_category[col_idx]
pick = np.argmax(self._data[row_idx, matrix_st:matrix_ed])
cond[i, pick + self._discrete_column_cond_st[col_idx]] = 1
cond[np.arange(batch), col_idxs] = 1

return cond

def sample_data(self, n, col, opt):
def sample_data(self, data, n, col, opt):
"""Sample data from original training data satisfying the sampled conditional vector.
Args:
data:
The training data.
Returns:
n rows of matrix data.
n:
n rows of matrix data.
"""
if col is None:
idx = np.random.randint(len(self._data), size=n)
return self._data[idx]
idx = np.random.randint(len(data), size=n)
return data[idx]

idx = []
for c, o in zip(col, opt):
idx.append(np.random.choice(self._rid_by_cat_cols[c][o]))

return self._data[idx]
return data[idx]

def dim_cond_vec(self):
"""Return the total number of categories."""
Expand Down
5 changes: 3 additions & 2 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
condvec = self._data_sampler.sample_condvec(self._batch_size)
if condvec is None:
c1, m1, col, opt = None, None, None, None
real = self._data_sampler.sample_data(self._batch_size, col, opt)
real = self._data_sampler.sample_data(
train_data, self._batch_size, col, opt)
else:
c1, m1, col, opt = condvec
c1 = torch.from_numpy(c1).to(self._device)
Expand All @@ -365,7 +366,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
perm = np.arange(self._batch_size)
np.random.shuffle(perm)
real = self._data_sampler.sample_data(
self._batch_size, col[perm], opt[perm])
train_data, self._batch_size, col[perm], opt[perm])
c2 = c1[perm]

fake = self._generator(fakez)
Expand Down

0 comments on commit 39cfbbe

Please sign in to comment.