diff --git a/tests/vi.py b/tests/vi.py index 9221dc3..5904136 100644 --- a/tests/vi.py +++ b/tests/vi.py @@ -89,17 +89,17 @@ def test_vi_convergence(self): dim = 2 text_full = [] - for i in range(100): + for i in range(10): text_full += text e0 = Embedding(vocabulary=vocabulary, dimensionality=dim) init_mean = False init_std = 0.2 - q_mu0, q_std_e0 = mean_field_vi(e, text, model="sgns", evaluate=False, ws=ws, batch_size=batch_size, init_mean=init_mean, init_std=init_std, epochs=5) + q_mu0, q_std_e0 = mean_field_vi(e0, text, model="sgns", evaluate=False, ws=ws, batch_size=batch_size, init_mean=init_mean, init_std=init_std, epochs=5) e = Embedding(vocabulary=vocabulary, dimensionality=dim) q_mu, q_std_e = mean_field_vi(e, text_full, model="sgns", evaluate=False, ws=ws, batch_size=batch_size, init_mean=init_mean, init_std=init_std, epochs=5) - self.assertGreater(np.mean(q_std_e.theta), np.mean(q_std_e0.theta)) + self.assertGreater(np.mean(q_std_e0.theta), np.mean(q_std_e.theta)) if __name__ == '__main__':