Skip to content

Commit

Permalink
test: fix error in test
Browse files Browse the repository at this point in the history
  • Loading branch information
ninpnin committed Jun 3, 2024
1 parent 4d9744d commit b358c2f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,17 @@ def test_vi_convergence(self):
dim = 2

text_full = []
for i in range(100):
for i in range(10):
text_full += text

e0 = Embedding(vocabulary=vocabulary, dimensionality=dim)
init_mean = False
init_std = 0.2
q_mu0, q_std_e0 = mean_field_vi(e, text, model="sgns", evaluate=False, ws=ws, batch_size=batch_size, init_mean=init_mean, init_std=init_std, epochs=5)
q_mu0, q_std_e0 = mean_field_vi(e0, text, model="sgns", evaluate=False, ws=ws, batch_size=batch_size, init_mean=init_mean, init_std=init_std, epochs=5)

e = Embedding(vocabulary=vocabulary, dimensionality=dim)
q_mu, q_std_e = mean_field_vi(e, text_full, model="sgns", evaluate=False, ws=ws, batch_size=batch_size, init_mean=init_mean, init_std=init_std, epochs=5)
self.assertGreater(np.mean(q_std_e.theta), np.mean(q_std_e0.theta))
self.assertGreater(np.mean(q_std_e0.theta), np.mean(q_std_e.theta))


if __name__ == '__main__':
Expand Down

0 comments on commit b358c2f

Please sign in to comment.