diff --git a/tests/test_Basefold.py b/tests/test_Basefold.py index f0b978f..0de174c 100644 --- a/tests/test_Basefold.py +++ b/tests/test_Basefold.py @@ -9,6 +9,8 @@ from merlin.merlin_transcript import MerlinTranscript from merkle import MerkleTree from mle2 import MLEPolynomial +from unipolynomial import UniPolynomial +from utils import log_2 class BasefoldTest(TestCase): def test_rep_encode(self): @@ -107,18 +109,25 @@ def test_basefold(self): self.assertTrue(Basefold.verify_basefold_evaluation_arg_multilinear_basis(2 ** num_vars * blowup_factor, commit, proof, us, v, 2, k, T, blowup_factor, num_verifier_queries)) def test_basefold_fri_monomial_basis(self): - vs = [1, 2, 3, 4] - table = [1, 2] - alpha = 2 + UniPolynomial.scalar_constructor = lambda x: x + table = [randint(1, 100) for _ in range(10)] + coeffs = [randint(0, 100) for _ in range(2)] + vs = [UniPolynomial(coeffs).evaluate(x) for x in table + [-x for x in table]] + alpha = randint(0, 100) result = Basefold.basefold_fri_monomial_basis(vs, table, alpha) - self.assertEqual(len(result), 2) + eval = coeffs[0] + alpha * coeffs[1] + for e in result: + self.assertEqual(e, eval) def test_basefold_fri_multilinear_basis(self): - vs = [1, 2, 3, 4] - table = [1, 2] - c = 2 + table = [randint(1, 100) for _ in range(10)] + coeffs = [randint(0, 100) for _ in range(2)] + vs = [UniPolynomial(coeffs).evaluate(x) for x in table + [-x for x in table]] + c = randint(0, 100) result = Basefold.basefold_fri_multilinear_basis(vs, table, c) - self.assertEqual(len(result), 2) + eval = (1 - c) * coeffs[0] + c * coeffs[1] + for e in result: + self.assertEqual(e, eval) if __name__ == '__main__':