Skip to content

Commit

Permalink
Revert Sampling in Bernoulli model
Browse files Browse the repository at this point in the history
  • Loading branch information
successar committed Aug 15, 2020
1 parent 7573197 commit 894bfba
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion Rationale_model/models/classifiers/enc_gen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def forward(self, document, kept_tokens, rationale=None, label=None, metadata=No
prob_z = kept_tokens.float() + prob_z * (1 - kept_tokens)
sampler = D.bernoulli.Bernoulli(probs=prob_z)

sample_z = mask.float()
sample_z = sampler.sample() * mask.float()
encoder_dict = self._encoder(sample_z=sample_z, label=label, metadata=metadata)

loss = 0.0
Expand Down

0 comments on commit 894bfba

Please sign in to comment.