From a1961bbf82e814921b6fcb974100da53ea8468ae Mon Sep 17 00:00:00 2001 From: Neil Date: Sat, 28 Sep 2024 14:59:06 +0800 Subject: [PATCH] test(kzg_hiding): add unit tests for KZG10Commitment Implement comprehensive test suite for KZG10Commitment class, covering setup, commit, open, check, and batch check operations. Include tests for both valid and invalid proofs. --- tests/test_kzg_hiding.py | 94 +++++++++++++++++++++++++++++++++++++ tests/test_unipolynomial.py | 13 ----- 2 files changed, 94 insertions(+), 13 deletions(-) create mode 100644 tests/test_kzg_hiding.py diff --git a/tests/test_kzg_hiding.py b/tests/test_kzg_hiding.py new file mode 100644 index 0000000..a812dee --- /dev/null +++ b/tests/test_kzg_hiding.py @@ -0,0 +1,94 @@ +import unittest +from random import randint +import sys +sys.path.append('../src') +sys.path.append('src') + +from kzg_hiding import KZG10Commitment, UniPolynomial, DummyGroup, Field, Commitment + +class TestKZG10Commitment(unittest.TestCase): + def setUp(self): + self.max_degree = 10 + self.hiding_bound = 3 + self.kzg = KZG10Commitment(DummyGroup(Field), DummyGroup(Field), self.max_degree, hiding_bound=self.hiding_bound, debug=True) + + def test_setup(self): + params = self.kzg.setup() + self.assertIn('powers_of_g', params) + self.assertIn('powers_of_gamma_g', params) + self.assertIn('h', params) + self.assertIn('beta_h', params) + self.assertIn('neg_powers_of_h', params) + self.assertEqual(len(params['powers_of_g']), self.max_degree + 1) + self.assertEqual(len(params['powers_of_gamma_g']), self.max_degree + 2) + + def test_commit(self): + poly = UniPolynomial([randint(0, 100) for _ in range(5)]) + commitment, random_ints = self.kzg.commit(poly) + self.assertIsInstance(commitment, Commitment) + self.assertEqual(len(random_ints), self.hiding_bound + 1) + + def test_open(self): + poly = UniPolynomial([randint(0, 100) for _ in range(5)]) + point = randint(0, 100) + commitment, random_ints = self.kzg.commit(poly) + proof = self.kzg.open(poly, point, random_ints) + self.assertIn('w', proof) + self.assertIn('random_v', proof) + + def test_check(self): + poly = UniPolynomial([randint(0, 100) for _ in range(5)]) + point = randint(0, 100) + commitment, random_ints = self.kzg.commit(poly) + value = poly.evaluate(point) + proof = self.kzg.open(poly, point, random_ints) + self.assertTrue(self.kzg.check(commitment, point, value, proof)) + + def test_check_invalid_proof(self): + poly = UniPolynomial([randint(0, 100) for _ in range(5)]) + point = randint(0, 100) + commitment, random_ints = self.kzg.commit(poly) + value = poly.evaluate(point) + invalid_proof = self.kzg.open(poly, point + 1, random_ints) # Invalid point + self.assertFalse(self.kzg.check(commitment, point, value, invalid_proof)) + + def test_batch_check(self): + num_polynomials = 5 + polynomials = [UniPolynomial([randint(0, 100) for _ in range(randint(5, 10))]) for _ in range(num_polynomials)] + points = [randint(0, 100) for _ in range(num_polynomials)] + + commitments = [] + values = [] + proofs = [] + + for p, point in zip(polynomials, points): + comm, random_ints = self.kzg.commit(p) + commitments.append(comm) + values.append(p.evaluate(point)) + proofs.append(self.kzg.open(p, point, random_ints)) + + self.assertTrue(self.kzg.batch_check(commitments, points, values, proofs)) + + def test_batch_check_invalid_proof(self): + num_polynomials = 5 + polynomials = [UniPolynomial([randint(0, 100) for _ in range(randint(5, 10))]) for _ in range(num_polynomials)] + points = [randint(0, 100) for _ in range(num_polynomials)] + + commitments = [] + values = [] + proofs = [] + + for p, point in zip(polynomials, points): + comm, random_ints = self.kzg.commit(p) + commitments.append(comm) + values.append(p.evaluate(point)) + proofs.append(self.kzg.open(p, point, random_ints)) + + # Invalidate one proof + invalid_index = randint(0, num_polynomials - 1) + proofs[invalid_index] = self.kzg.open(polynomials[invalid_index], points[invalid_index] + 1, random_ints) + + self.assertFalse(self.kzg.batch_check(commitments, points, values, proofs)) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_unipolynomial.py b/tests/test_unipolynomial.py index 217e132..7c3586b 100644 --- a/tests/test_unipolynomial.py +++ b/tests/test_unipolynomial.py @@ -211,19 +211,6 @@ def test_compute_eval(self): self.assertEqual(evals, expected) - # def test_compute_linear_combination_linear_moduli(self): - # from sage.all import GF - # from field import magic - - # Fp = magic(GF(193)) - # domain = [Fp(1), Fp(2), Fp(3), Fp(4)] - # tree = UniPolynomial.construct_subproduct_tree_fix(domain) - # ws = [Fp(1), Fp(8), Fp(27), Fp(64)] - # result = UniPolynomial.compute_linear_combination_linear_moduli_fix(tree, ws, domain) - # expected = UniPolynomial.compute_eval_fix(tree, result, domain) - - # self.assertEqual(ws, expected) - def test_compute_z_derivative(self): z = [1, 2, 3, 4] result = UniPolynomial.compute_z_derivative(z)