Skip to content

Commit

Permalink
fix(circle unipoly kzg): fix bugs by testing
Browse files Browse the repository at this point in the history
  • Loading branch information
ahy231 committed Nov 15, 2024
1 parent 15ec106 commit f01201f
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 56 deletions.
27 changes: 12 additions & 15 deletions src/circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def group_mul(g1, g2):

def pi(t):
# x^2 - y^2 == 2 * x^2 - 1 (x^2 + y^2 = 1)
return C31(2 * t**2 - 1)
return F31(2 * t**2 - 1)

def pie_group(D):
D_new = []
Expand Down Expand Up @@ -162,12 +162,12 @@ def deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta, debug=False):
vp_denom_invs = batch_multiplicative_inverse(vp_demons)
if debug: print('vp_nums:', vp_nums, 'vp_denom_invs:', vp_denom_invs, 'p_at_zeta:', p_at_zeta, 'evals:', evals)

return [vp_denom_invs[i] * vp_nums[i] * group_mul(group_inv(p_at_zeta), evals[i]) for i in range(len(evals))]
return [vp_denom_invs[i] * vp_nums[i] * (evals[i] - p_at_zeta) for i in range(len(evals))]

def deep_quotient_reduce_row(alpha, x, zeta, ps_at_x, ps_at_zeta, debug=False):
vp_num, vp_denom = deep_quotient_vanishing_part(x, zeta, alpha)
if debug: print('vp_num:', vp_num, 'vp_denom:', vp_denom, 'ps_at_x:', ps_at_x, 'ps_at_zeta:', ps_at_zeta)
return vp_num * group_mul(group_inv(ps_at_zeta), ps_at_x) / vp_denom
return vp_num * (ps_at_x - ps_at_zeta) / vp_denom

def deep_quotient_reduce_raw(evals, domain, alpha, zeta, p_at_zeta, debug=False):
res = []
Expand Down Expand Up @@ -243,8 +243,8 @@ def _ifft_first_step(cls, f):
for t in f:
x, y = t

f0[x] = (f[t] + f[group_inv(t)]) / C31(2)
f1[x] = (f[t] - f[group_inv(t)]) / (C31(2) * y)
f0[x] = (f[t] + f[group_inv(t)]) / F31(2)
f1[x] = (f[t] - f[group_inv(t)]) / (F31(2) * y)

# Check that f is divided into 2 parts correctly
assert f[t] == f0[x] + y * f1[x]
Expand All @@ -269,6 +269,7 @@ def _ifft_normal_step(cls, f):

# Check that f is divided into 2 parts correctly
assert f[x] == f0[pi(x)] + x * f1[pi(x)]
assert f[-x] == f0[pi(x)] - x * f1[pi(x)]

return cls._ifft_normal_step(f0) + cls._ifft_normal_step(f1)

Expand Down Expand Up @@ -392,7 +393,7 @@ def fold_y(cls, evals, domain, beta, debug=False):
for i, (_, y) in enumerate(domain[:N//2]):
f0 = (left[i] + right[i]) / 2
f1 = (left[i] - right[i]) / (2 * y)
evals[i] = f0 + group_mul(beta, f1)
evals[i] = f0 + f1 * beta
if debug: print('fold y')
if debug: print(f"f0 = (({left[i]}) + ({right[i]}))/2 = {f0}")
if debug: print(f"f1 = (({left[i]}) - ({right[i]}))/(2 * {y}) = {f1}")
Expand Down Expand Up @@ -435,7 +436,7 @@ def fold_x(cls, f, D, r, debug=False):
f1 = (left[i] - right[i]) * (1 / (F31(2) * x))
# f[:N//2] stores the folded polynomial
if debug: print('fold x')
if debug: print(f"f[{i}] = {f[i]} = ({left[i]} + {right[i]})/2 + {r} * ({left[i]} - {right[i]})/(2 * {x})")
if debug: print(f"f[{i}] = {f[i]} = (({left[i]}) + ({right[i]}))/2 + {r} * (({left[i]}) - ({right[i]}))/(2 * {x})")
f[i] = f0 + r * f1
# if debug: print(f"{f[i]} = ({left[i]} + {right[i]})/2 + {r} * ({left[i]} - {right[i]})/(2 * {x})")
# reuse f[N//2:] to store new domain
Expand Down Expand Up @@ -581,7 +582,7 @@ def verify_query(cls, index, steps, reduced_opening, log_max_height, debug=False
return folded_eval

@classmethod
def verify(cls, proof, transcript, open_input, debug=False):
def verify(cls, proof, blowup_factor, transcript, open_input, debug=False):
assert isinstance(transcript, MerlinTranscript), "transcript should be a MerlinTranscript"

betas = []
Expand All @@ -593,11 +594,11 @@ def verify(cls, proof, transcript, open_input, debug=False):
transcript.append_message(b"final_poly", bytes(str(proof["final_poly"]), 'ascii'))

folded_eval = 0
log_max_height = len(proof["commit_phase_commits"]) + log_2(blowup_factor)
for qp in proof["query_proofs"]:
index = int.from_bytes(transcript.challenge_bytes(b"query", 4), "big")
if debug: print('query:', index)

log_max_height = len(proof["commit_phase_commits"])
index >>= (32 - log_max_height - 1)
index_sibling = (1 << log_max_height) * 2 - 1 - index
if debug: print('log_max_height:', log_max_height, ', index:', index, ', index_sibling:', index_sibling)
Expand Down Expand Up @@ -747,19 +748,15 @@ def open_input(index, input_proof):

return fri_input

FRI.verify(proof["fri_proof"], transcript, open_input, debug)
FRI.verify(proof["fri_proof"], 1 << log_blowup, transcript, open_input, debug)

if __name__ == "__main__":
from random import randint
rand_ext = lambda: randint(0, 31) + randint(0, 31) * I
# evals = [rand_ext() for _ in range(2)]
evals = [F31(1), F31(2), F31(3), F31(4)]
evals = [F31(randint(0, 31)) for _ in range(4)]
domain = CirclePCS.natural_domain_for_degree(len(evals))
print('domain:', domain)
log_blowup = 1

commitment, lde = CirclePCS.commit(evals, domain, 1 << log_blowup)
print('lde:', lde)

transcript = MerlinTranscript(b'circle pcs')
transcript.append_message(b'commitment', bytes(str(commitment.root), 'ascii'))
Expand Down
104 changes: 83 additions & 21 deletions tests/test_circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,25 @@
sys.path.append('src')
sys.path.append('../src')

from circle import CFFT, F31, C31, FRI, eval_at_point_raw, twin_cosets, combine, standard_position_cosets, log_2, deep_quotient_reduce, deep_quotient_reduce_raw, g_30
from circle import CirclePCS, CFFT, F31, FRI, evaluate_at_point, eval_at_point_raw, extract_lambda, twin_cosets, combine, standard_position_cosets, log_2, deep_quotient_reduce, deep_quotient_reduce_raw, g_30, group_mul, group_inv
from merlin.merlin_transcript import MerlinTranscript

def fold(lde, domain, chunk_size, fold_y=False):
def fold(lde, domain, alpha=3, fold_y=False):
if fold_y:
assert len(domain) == len(lde), f'len(domain) != len(lde), {len(domain)}, {len(lde)}'
else:
assert len(domain) == len(lde) * 2, f'len(domain) != len(lde) * 2, {len(domain)}, {len(lde) * 2}'
res = []
for j in range(len(lde) // chunk_size):
for i in range(chunk_size // 2):
left = lde[j * chunk_size + i]
right = lde[(j + 1) * chunk_size - i - 1]
t = domain[i][1 if fold_y else 0]
# print('t:', t)
f0 = (left + right) / F31(2)
f1 = (left - right) / (F31(2) * t)
assert f0 + f1 * t == left
assert f0 - f1 * t == right
res += [f0 + f1 * 3]
for i in range(len(lde) // 2):
left = lde[i]
right = lde[-i - 1]
t = domain[i][1 if fold_y else 0]
# print('t:', t)
f0 = (left + right) / F31(2)
f1 = (left - right) / (F31(2) * t)
assert f0 + f1 * t == left
assert f0 - f1 * t == right
res += [f0 + f1 * alpha]
return res

class TestCircle(unittest.TestCase):
Expand Down Expand Up @@ -54,8 +53,8 @@ def test_fold(self):

domain_lde = standard_position_cosets[log_2(len(evals) * blowup_factor)]
# print('domain_lde:', domain_lde)
folded = fold(lde, domain_lde, len(lde), fold_y=True)
folded_folded = fold(folded, domain_lde, len(lde) // 2, fold_y=False)
folded = fold(lde, domain_lde, 3, fold_y=True)
folded_folded = fold(folded, domain_lde, 3, fold_y=False)
assert folded_folded[0] == folded_folded[1], f'folded_folded[0] != folded_folded[1], {folded_folded[0]}, {folded_folded[1]}'

def test_fri_prove(self):
Expand All @@ -66,7 +65,7 @@ def test_fri_prove(self):

domain_lde = standard_position_cosets[log_2(len(evals) * blowup_factor)]
# print('domain_lde:', domain_lde)
folded = fold(lde, domain_lde, len(lde), fold_y=True)
folded = fold(lde, domain_lde, 3, fold_y=True)

f0 = [lde[0]] + lde[3:5] + [lde[7]]
f1 = lde[1:3] + lde[5:7]
Expand All @@ -75,18 +74,81 @@ def test_fri_prove(self):
assert CFFT.ifft(CFFT.vec_2_poly(f0, tcs[0])) == CFFT.ifft(CFFT.vec_2_poly(f1, tcs[1]))

transcript = MerlinTranscript(b'TEST')
_fri_proof = FRI.prove(folded, blowup_factor, [x[0] for x in domain_lde[:len(folded)]], transcript, lambda x: None, 1)
FRI.prove(folded, blowup_factor, [x[0] for x in domain_lde[:len(folded)]], transcript, lambda x: None, 1)

# def test_eval_at_point(self):
# evals = [F31(1), F31(2), F31(3), F31(4)]
# domain = standard_position_cosets[log_2(len(evals))]
# p = evaluate_at_point(evals, domain, g_30 ** 3)
# expected = eval_at_point_raw(evals, domain, g_30 ** 3)
# assert p == expected, f'evaluate_at_point error, {p}, {expected}'

def test_deep_quotient_reduce(self):
evals = [C31(1), C31(2), C31(3), C31(4)]
evals = [F31(1), F31(2), F31(3), F31(4)]
domain = standard_position_cosets[log_2(len(evals))]
alpha = 3
zeta = g_30 ** 6
zeta = g_30 ** 3
p_at_zeta = eval_at_point_raw(evals, domain, zeta)

reduced = deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta)
expected = deep_quotient_reduce_raw(evals, domain, alpha, zeta, p_at_zeta)
blowup_factor = 2
lde = CFFT.extrapolate(evals, domain, blowup_factor)
domain_lde = standard_position_cosets[log_2(len(evals) * blowup_factor)]

reduced = deep_quotient_reduce(lde, domain_lde, alpha, zeta, p_at_zeta)
expected = deep_quotient_reduce_raw(lde, domain_lde, alpha, zeta, p_at_zeta)
assert reduced == expected, f'deep_quotient_reduce error, {reduced}, {expected}'

coeffs = CFFT.ifft(CFFT.vec_2_poly(reduced, domain_lde))
for i, c in enumerate(coeffs):
if i % 2 == 1 and i != 1:
assert c == F31(0), f'coeffs[{i}] != 0, coeffs: {coeffs}'

def test_extract_lambda(self):
evals = [F31(1), F31(2), F31(3), F31(4)]
lde = CFFT.extrapolate(evals, standard_position_cosets[2], 2)
domain = standard_position_cosets[3]

zeta = g_30 ** 3
p_at_zeta = eval_at_point_raw(evals, standard_position_cosets[2], zeta)

diff = [group_mul(group_inv(zeta), x) for x in domain]

v_p = lambda x, y: (1 - x, -y)
v_p_diff = [v_p(x, y) for x, y in diff]

v_p_num = [x - y * 5 for x, y in v_p_diff]
v_p_den = [x*x + y*y for x, y in v_p_diff]

q_f = lambda f, fp, vp_num, vp_den: (f - fp) * vp_num / vp_den
q = [q_f(eval, p_at_zeta, num, den) for eval, num, den in zip(lde, v_p_num, v_p_den)]

new_lde, _ = extract_lambda(q, 1)

q_interpolated = CFFT.ifft(CFFT.vec_2_poly(new_lde, standard_position_cosets[3]))
for i, c in enumerate(q_interpolated):
if i % 2 == 1:
assert c == F31(0), f'coeffs[{i}] != 0, coeffs: {q_interpolated}'

def test_circle_pcs(self):
from random import randint
evals = [F31(randint(0, 31)) for _ in range(4)]
domain = CirclePCS.natural_domain_for_degree(len(evals))
log_blowup = 1

commitment, lde = CirclePCS.commit(evals, domain, 1 << log_blowup)

transcript = MerlinTranscript(b'circle pcs')
transcript.append_message(b'commitment', bytes(str(commitment.root), 'ascii'))

query_num = 1

domain = CirclePCS.natural_domain_for_degree(len(lde))
point = CirclePCS.G5[5]
proof = CirclePCS.open(lde, commitment, point, log_blowup, transcript, query_num, True)

transcript = MerlinTranscript(b'circle pcs')
transcript.append_message(b'commitment', bytes(str(commitment.root), 'ascii'))
CirclePCS.verify(commitment.root, domain, log_blowup, point, eval_at_point_raw(evals, CirclePCS.natural_domain_for_degree(len(evals)), point), proof, transcript, True)

if __name__ == '__main__':
unittest.main()
54 changes: 35 additions & 19 deletions tests/test_kzg_hiding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ class TestKZG10Commitment(unittest.TestCase):
def setUp(self):
self.max_degree = 10
self.hiding_bound = 3
self.kzg = KZG10Commitment(DummyGroup(Field), DummyGroup(Field), self.max_degree, hiding_bound=self.hiding_bound, debug=True)
self.kzg = KZG10Commitment(DummyGroup(Field), DummyGroup(Field), debug=True)

def test_setup(self):
params = self.kzg.setup()
params = self.kzg.setup(self.max_degree)
self.assertIn('powers_of_g', params)
self.assertIn('powers_of_gamma_g', params)
self.assertIn('h', params)
Expand All @@ -23,34 +23,46 @@ def test_setup(self):
self.assertEqual(len(params['powers_of_gamma_g']), self.max_degree + 2)

def test_commit(self):
params = self.kzg.setup(self.max_degree)
powers, vk = self.kzg.trim(params, self.max_degree)

poly = UniPolynomial([randint(0, 100) for _ in range(5)])
commitment, random_ints = self.kzg.commit(poly)
commitment, random_ints = self.kzg.commit(powers, poly, self.hiding_bound)
self.assertIsInstance(commitment, Commitment)
self.assertEqual(len(random_ints), self.hiding_bound + 1)

def test_open(self):
params = self.kzg.setup(self.max_degree)
powers, vk = self.kzg.trim(params, self.max_degree)

poly = UniPolynomial([randint(0, 100) for _ in range(5)])
point = randint(0, 100)
commitment, random_ints = self.kzg.commit(poly)
proof = self.kzg.open(poly, point, random_ints)
commitment, random_ints = self.kzg.commit(powers, poly, self.hiding_bound)
proof = self.kzg.open(powers, poly, point, random_ints, True)
self.assertIn('w', proof)
self.assertIn('random_v', proof)

def test_check(self):
params = self.kzg.setup(self.max_degree)
powers, vk = self.kzg.trim(params, self.max_degree)

poly = UniPolynomial([randint(0, 100) for _ in range(5)])
point = randint(0, 100)
commitment, random_ints = self.kzg.commit(poly)
commitment, random_ints = self.kzg.commit(powers, poly, self.hiding_bound)
value = poly.evaluate(point)
proof = self.kzg.open(poly, point, random_ints)
self.assertTrue(self.kzg.check(commitment, point, value, proof))
proof = self.kzg.open(powers, poly, point, random_ints, True)
self.assertTrue(self.kzg.check(vk, commitment, point, value, proof, True))

def test_check_invalid_proof(self):
params = self.kzg.setup(self.max_degree)
powers, vk = self.kzg.trim(params, self.max_degree)

poly = UniPolynomial([randint(0, 100) for _ in range(5)])
point = randint(0, 100)
commitment, random_ints = self.kzg.commit(poly)
commitment, random_ints = self.kzg.commit(powers, poly, self.hiding_bound)
value = poly.evaluate(point)
invalid_proof = self.kzg.open(poly, point + 1, random_ints) # Invalid point
self.assertFalse(self.kzg.check(commitment, point, value, invalid_proof))
invalid_proof = self.kzg.open(powers, poly, point + 1, random_ints, True) # Invalid point
self.assertFalse(self.kzg.check(vk, commitment, point, value, invalid_proof, True))

def test_batch_check(self):
num_polynomials = 5
Expand All @@ -60,35 +72,39 @@ def test_batch_check(self):
commitments = []
values = []
proofs = []
params = self.kzg.setup(self.max_degree)
powers, vk = self.kzg.trim(params, self.max_degree)

for p, point in zip(polynomials, points):
comm, random_ints = self.kzg.commit(p)
comm, random_ints = self.kzg.commit(powers, p, self.hiding_bound)
commitments.append(comm)
values.append(p.evaluate(point))
proofs.append(self.kzg.open(p, point, random_ints))
proofs.append(self.kzg.open(powers, p, point, random_ints, True))

self.assertTrue(self.kzg.batch_check(commitments, points, values, proofs))
self.assertTrue(self.kzg.batch_check(vk, commitments, points, values, proofs, True))

def test_batch_check_invalid_proof(self):
num_polynomials = 5
polynomials = [UniPolynomial([randint(0, 100) for _ in range(randint(5, 10))]) for _ in range(num_polynomials)]
points = [randint(0, 100) for _ in range(num_polynomials)]

params = self.kzg.setup(self.max_degree)
powers, vk = self.kzg.trim(params, self.max_degree)

commitments = []
values = []
proofs = []

for p, point in zip(polynomials, points):
comm, random_ints = self.kzg.commit(p)
comm, random_ints = self.kzg.commit(powers, p, self.hiding_bound)
commitments.append(comm)
values.append(p.evaluate(point))
proofs.append(self.kzg.open(p, point, random_ints))
proofs.append(self.kzg.open(powers, p, point, random_ints, True))

# Invalidate one proof
invalid_index = randint(0, num_polynomials - 1)
proofs[invalid_index] = self.kzg.open(polynomials[invalid_index], points[invalid_index] + 1, random_ints)
proofs[invalid_index] = self.kzg.open(powers, polynomials[invalid_index], points[invalid_index] + 1, random_ints, True)

self.assertFalse(self.kzg.batch_check(commitments, points, values, proofs))
self.assertFalse(self.kzg.batch_check(vk, commitments, points, values, proofs, True))

if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_unipolynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test_uni_eval_from_evals(self):
domain = [1, 2, 3, 4]
z = 5
result = UniPolynomial.uni_eval_from_evals(evals, z, domain)
self.assertEqual(result, 125)
self.assertEqual(int(result), 125)

if __name__ == '__main__':
main()

0 comments on commit f01201f

Please sign in to comment.