Skip to content

Commit

Permalink
test(circle.py): test for circle
Browse files Browse the repository at this point in the history
  • Loading branch information
ahy231 committed Nov 14, 2024
1 parent 9dda07a commit 15ec106
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 18 deletions.
55 changes: 37 additions & 18 deletions src/circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,23 +130,50 @@ def evaluate_at_point(evals, domain, point, debug=False):

return sum([lagrange_den[i] * evals[i] for i in range(len(evals))]) * lagrange_num

def eval_at_p_recursive(evals, twiddle, debug=False):
if len(evals) == 1:
return evals[0]
else:
f0 = eval_at_p_recursive(evals[:len(evals)//2], pi(twiddle), debug)
f1 = eval_at_p_recursive(evals[len(evals)//2:], pi(twiddle), debug)
return f0 + f1 * twiddle

def eval_at_point_raw(evals, domain, point, debug=False):

x, y = point
poly = CFFT.vec_2_poly(evals, domain)
coeffs = CFFT.ifft(poly)

left, right = coeffs[:len(coeffs)//2], coeffs[len(coeffs)//2:]
left_eval = eval_at_p_recursive(left, x)
right_eval = eval_at_p_recursive(right, x)

return left_eval + right_eval * y

def deep_quotient_vanishing_part(x, zeta, alpha_pow_width, debug=False):
v_p = lambda p, at: (1 - (p - at)[0], -(p - at)[1])
v_p = lambda p, at: (1 - group_mul(p, group_inv(at))[0], -group_mul(p, group_inv(at))[1])
re_v_zeta, im_v_zeta = v_p(x, zeta)
# if debug: print('re_v_zeta:', re_v_zeta, 'im_v_zeta:', im_v_zeta)
return (re_v_zeta - alpha_pow_width * im_v_zeta, re_v_zeta ** 2 + im_v_zeta ** 2)
# return (re_v_zeta - alpha_pow_width * im_v_zeta, re_v_zeta ** 2 + im_v_zeta ** 2)
return (re_v_zeta - im_v_zeta * alpha_pow_width, re_v_zeta ** 2 + im_v_zeta ** 2)

def deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta, debug=False):
vp_nums, vp_demons = zip(*[(deep_quotient_vanishing_part(x, zeta, alpha, debug)) for x in domain])
vp_denom_invs = batch_multiplicative_inverse(vp_demons)
if debug: print('vp_nums:', vp_nums, 'vp_denom_invs:', vp_denom_invs, 'p_at_zeta:', p_at_zeta, 'evals:', evals)

return [vp_nums[i] * vp_denom_invs[i] * (-p_at_zeta + evals[i]) for i in range(len(evals))]
return [vp_denom_invs[i] * vp_nums[i] * group_mul(group_inv(p_at_zeta), evals[i]) for i in range(len(evals))]

def deep_quotient_reduce_row(alpha, x, zeta, ps_at_x, ps_at_zeta, debug=False):
vp_num, vp_denom = deep_quotient_vanishing_part(x, zeta, alpha)
if debug: print('vp_num:', vp_num, 'vp_denom:', vp_denom, 'ps_at_x:', ps_at_x, 'ps_at_zeta:', ps_at_zeta)
return vp_num * (-ps_at_zeta + ps_at_x) / vp_denom
return vp_num * group_mul(group_inv(ps_at_zeta), ps_at_x) / vp_denom

def deep_quotient_reduce_raw(evals, domain, alpha, zeta, p_at_zeta, debug=False):
res = []
for ps_at_x, x in zip(evals, domain):
res.append(deep_quotient_reduce_row(alpha, x, zeta, ps_at_x, p_at_zeta, debug))
return res

def extract_lambda(lde, log_blowup, debug=False):
if debug:
Expand Down Expand Up @@ -208,13 +235,6 @@ def combine(cosets):
res += [t]
return res

# test twin_cosets
tcs = twin_cosets(2, 4)
for tc in tcs:
for t in tc:
assert t in standard_position_cosets[log_2(8)]
assert combine(tcs) == standard_position_cosets[log_2(8)], f'combine error, {combine(tcs)}, {standard_position_cosets[log_2(8)]}'

class CFFT:
@classmethod
def _ifft_first_step(cls, f):
Expand Down Expand Up @@ -344,8 +364,8 @@ def vec_2_poly(cls, vec, domain):
return f

@classmethod
def poly_2_vec(cls, poly):
return [poly[t] for t in poly]
def poly_2_vec(cls, poly, domain):
return [poly[t] for t in domain]

@classmethod
def extrapolate(cls, evals, domain, blowup_factor):
Expand All @@ -354,8 +374,7 @@ def extrapolate(cls, evals, domain, blowup_factor):
cosets = twin_cosets(blowup_factor, len(evals))
res = []
for coset in cosets:
res += [cls.fft(coeffs, coset)]
res = [cls.poly_2_vec(x) for x in res]
res += [cls.poly_2_vec(cls.fft(coeffs, coset), coset)]
return combine(res)

class FRI:
Expand Down Expand Up @@ -625,11 +644,11 @@ def open(cls, evals, evals_commit, zeta, log_blowup, transcript, num_queries, de

# evaluate the polynomial at the point zeta
domain = cls.natural_domain_for_degree(len(evals))
p_at_zeta = evaluate_at_point(evals, domain, zeta, debug)
p_at_zeta = eval_at_point_raw(evals, domain, zeta, debug)
if debug: print('p_at_zeta:', p_at_zeta)

# deep quotient
reduced_opening = deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta, debug)
reduced_opening = deep_quotient_reduce_raw(evals, domain, alpha, zeta, p_at_zeta, debug)
if debug: print('reduced_opening:', reduced_opening)
# extract lambda
first_layer, lambda_ = extract_lambda(reduced_opening, log_blowup, debug)
Expand Down Expand Up @@ -753,4 +772,4 @@ def open_input(index, input_proof):

transcript = MerlinTranscript(b'circle pcs')
transcript.append_message(b'commitment', bytes(str(commitment.root), 'ascii'))
CirclePCS.verify(commitment.root, domain, log_blowup, point, evaluate_at_point(evals, CirclePCS.natural_domain_for_degree(len(evals)), point), proof, transcript, True)
CirclePCS.verify(commitment.root, domain, log_blowup, point, eval_at_point_raw(evals, CirclePCS.natural_domain_for_degree(len(evals)), point), proof, transcript, True)
92 changes: 92 additions & 0 deletions tests/test_circle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import unittest
from sage.all import *
import sys

sys.path.append('src')
sys.path.append('../src')

from circle import CFFT, F31, C31, FRI, eval_at_point_raw, twin_cosets, combine, standard_position_cosets, log_2, deep_quotient_reduce, deep_quotient_reduce_raw, g_30
from merlin.merlin_transcript import MerlinTranscript

def fold(lde, domain, chunk_size, fold_y=False):
if fold_y:
assert len(domain) == len(lde), f'len(domain) != len(lde), {len(domain)}, {len(lde)}'
else:
assert len(domain) == len(lde) * 2, f'len(domain) != len(lde) * 2, {len(domain)}, {len(lde) * 2}'
res = []
for j in range(len(lde) // chunk_size):
for i in range(chunk_size // 2):
left = lde[j * chunk_size + i]
right = lde[(j + 1) * chunk_size - i - 1]
t = domain[i][1 if fold_y else 0]
# print('t:', t)
f0 = (left + right) / F31(2)
f1 = (left - right) / (F31(2) * t)
assert f0 + f1 * t == left
assert f0 - f1 * t == right
res += [f0 + f1 * 3]
return res

class TestCircle(unittest.TestCase):
def test_twin_cosets(self):
# test twin_cosets
tcs = twin_cosets(2, 4)
for tc in tcs:
for t in tc:
assert t in standard_position_cosets[log_2(8)]
assert combine(tcs) == standard_position_cosets[log_2(8)], f'combine error, {combine(tcs)}, {standard_position_cosets[log_2(8)]}'

def test_extrapolate(self):
evals = [1, 2, 3, 4]
domain = standard_position_cosets[log_2(len(evals))]
blowup_factor = 2
lde = CFFT.extrapolate(evals, domain, blowup_factor)

assert len(lde) == len(evals) * blowup_factor, f'len(lde) != len(evals) * blowup_factor, {len(lde)}, {len(evals) * blowup_factor}'
for i, p in enumerate(standard_position_cosets[log_2(len(evals) * blowup_factor)]):
assert eval_at_point_raw(evals, domain, p) == lde[i], f'evaluate_at_point error, {eval_at_point_raw(evals, domain, p)}, {lde[i]}'

def test_fold(self):
evals = [1, 2, 3, 4]
domain = standard_position_cosets[log_2(len(evals))]
blowup_factor = 2
lde = CFFT.extrapolate(evals, domain, blowup_factor)

domain_lde = standard_position_cosets[log_2(len(evals) * blowup_factor)]
# print('domain_lde:', domain_lde)
folded = fold(lde, domain_lde, len(lde), fold_y=True)
folded_folded = fold(folded, domain_lde, len(lde) // 2, fold_y=False)
assert folded_folded[0] == folded_folded[1], f'folded_folded[0] != folded_folded[1], {folded_folded[0]}, {folded_folded[1]}'

def test_fri_prove(self):
evals = [1, 2, 3, 4]
domain = standard_position_cosets[log_2(len(evals))]
blowup_factor = 2
lde = CFFT.extrapolate(evals, domain, blowup_factor)

domain_lde = standard_position_cosets[log_2(len(evals) * blowup_factor)]
# print('domain_lde:', domain_lde)
folded = fold(lde, domain_lde, len(lde), fold_y=True)

f0 = [lde[0]] + lde[3:5] + [lde[7]]
f1 = lde[1:3] + lde[5:7]

tcs = twin_cosets(2, 4)
assert CFFT.ifft(CFFT.vec_2_poly(f0, tcs[0])) == CFFT.ifft(CFFT.vec_2_poly(f1, tcs[1]))

transcript = MerlinTranscript(b'TEST')
_fri_proof = FRI.prove(folded, blowup_factor, [x[0] for x in domain_lde[:len(folded)]], transcript, lambda x: None, 1)

def test_deep_quotient_reduce(self):
evals = [C31(1), C31(2), C31(3), C31(4)]
domain = standard_position_cosets[log_2(len(evals))]
alpha = 3
zeta = g_30 ** 6
p_at_zeta = eval_at_point_raw(evals, domain, zeta)

reduced = deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta)
expected = deep_quotient_reduce_raw(evals, domain, alpha, zeta, p_at_zeta)
assert reduced == expected, f'deep_quotient_reduce error, {reduced}, {expected}'

if __name__ == '__main__':
unittest.main()

0 comments on commit 15ec106

Please sign in to comment.