Skip to content

Commit

Permalink
Remove 0's from category_freq
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Jan 22, 2024
1 parent aae2bbc commit b81fc6e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion ctgan/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def sample_original_condvec(self, batch):
return None

category_freq = self._discrete_column_category_prob.flatten()
category_freq = category_freq[category_freq != 0]
category_freq = category_freq / np.sum(category_freq)
col_idxs = np.random.choice(np.arange(len(category_freq)), batch, p=category_freq)

cond = np.zeros((batch, self._n_categories), dtype='float32')
cond[np.arange(batch), col_idxs] = 1

Expand Down

0 comments on commit b81fc6e

Please sign in to comment.