Skip to content

Commit

Permalink
test(test_circle): Add test cases
Browse files Browse the repository at this point in the history
test fft list, eval at point raw, batch inversion
  • Loading branch information
ahy231 committed Nov 18, 2024
1 parent 2596e43 commit 99b182d
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions tests/test_circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
sys.path.append('src')
sys.path.append('../src')

from circle import CirclePCS, CFFT, F31, FRI, evaluate_at_point, eval_at_point_raw, extract_lambda, twin_cosets, combine, standard_position_cosets, log_2, deep_quotient_reduce, deep_quotient_reduce_raw, g_30, group_mul, group_inv
from circle import CirclePCS, CFFT, F31, FRI, batch_multiplicative_inverse, batch_multiplicative_inverse_raw, evaluate_at_point, eval_at_point_raw, extract_lambda, twin_cosets, combine, standard_position_cosets, log_2, deep_quotient_reduce, deep_quotient_reduce_raw, g_30, group_mul, group_inv
from merlin.merlin_transcript import MerlinTranscript

def fold(lde, domain, alpha=3, fold_y=False):
Expand Down Expand Up @@ -40,17 +40,16 @@ def test_extrapolate(self):
evals = [1, 2, 3, 4]
domain = standard_position_cosets[log_2(len(evals))]
blowup_factor = 2
lde = CFFT.extrapolate(evals, domain, blowup_factor)
lde = CFFT.extrapolate(evals, blowup_factor)

assert len(lde) == len(evals) * blowup_factor, f'len(lde) != len(evals) * blowup_factor, {len(lde)}, {len(evals) * blowup_factor}'
for i, p in enumerate(standard_position_cosets[log_2(len(evals) * blowup_factor)]):
assert eval_at_point_raw(evals, domain, p) == lde[i], f'evaluate_at_point error, {eval_at_point_raw(evals, domain, p)}, {lde[i]}'

def test_fold(self):
evals = [F31(randint(0, 31)) for _ in range(4)]
domain = standard_position_cosets[log_2(len(evals))]
evals = [F31(randint(0, 30)) for _ in range(4)]
blowup_factor = 2
lde = CFFT.extrapolate(evals, domain, blowup_factor)
lde = CFFT.extrapolate(evals, blowup_factor)

domain_lde = standard_position_cosets[log_2(len(evals) * blowup_factor)]
# print('domain_lde:', domain_lde)
Expand All @@ -59,18 +58,42 @@ def test_fold(self):
assert folded_folded[0] == folded_folded[1], f'folded_folded[0] != folded_folded[1], {folded_folded[0]}, {folded_folded[1]}'

def test_fft(self):
evals = [F31(randint(0, 31)) for _ in range(4)]
evals = [F31(randint(0, 30)) for _ in range(4)]
domain = standard_position_cosets[log_2(len(evals))]

coeffs = CFFT.ifft(CFFT.vec_2_poly(evals, domain))
evals_fft = CFFT.poly_2_vec(CFFT.fft(coeffs, domain), domain)
assert evals_fft == evals, f'evals_fft != evals, {evals_fft}, {evals}'

def test_fft_list(self):
evals = [F31(randint(0, 30)) for _ in range(4)]
coeffs = CFFT.ifft_list(evals)
evals_fft = CFFT.fft_list(coeffs)
assert evals_fft == evals, f'evals_fft != evals, {evals_fft}, {evals}'

def test_eval_at_point_raw(self):
evals = [F31(randint(0, 30)) for _ in range(4)]
domain = standard_position_cosets[log_2(len(evals))]
new_domain = standard_position_cosets[log_2(len(evals)) + 1]
point = new_domain[randint(0, len(new_domain) - 1)]
eval_at_p = eval_at_point_raw(evals, domain, point)
coeffs = CFFT.ifft(CFFT.vec_2_poly(evals, domain))
new_coeffs = [F31(0) for _ in range(len(coeffs) * 2)]
for i in range(len(coeffs) * 2):
new_coeffs[i] = coeffs[i // 2] if i % 2 == 0 else F31(0)
new_evals = CFFT.fft(new_coeffs, new_domain)
assert new_evals[point] == eval_at_p, f'new_evals[point] != eval_at_p, {new_evals[point]}, {eval_at_p}'

def test_batch_multiplicative_inverse(self):
evals = [F31(randint(1, 30)) for _ in range(14)]
inv = batch_multiplicative_inverse_raw(evals)
assert inv == batch_multiplicative_inverse(evals), f'inv != batch_multiplicative_inverse, {inv}, {batch_multiplicative_inverse(evals)}'

def test_fri_prove(self):
evals = [F31(randint(0, 31)) for _ in range(4)]
evals = [F31(randint(0, 30)) for _ in range(4)]
domain = standard_position_cosets[log_2(len(evals))]
blowup_factor = 2
lde = CFFT.extrapolate(evals, domain, blowup_factor)
lde = CFFT.extrapolate(evals, blowup_factor)

domain_lde = standard_position_cosets[log_2(len(evals) * blowup_factor)]
# print('domain_lde:', domain_lde)
Expand All @@ -93,14 +116,14 @@ def test_fri_prove(self):
# assert p == expected, f'evaluate_at_point error, {p}, {expected}'

def test_deep_quotient_reduce(self):
evals = [F31(randint(0, 31)) for _ in range(4)]
evals = [F31(randint(0, 30)) for _ in range(4)]
domain = standard_position_cosets[log_2(len(evals))]
alpha = 3
zeta = g_30 ** 3
p_at_zeta = eval_at_point_raw(evals, domain, zeta)

blowup_factor = 2
lde = CFFT.extrapolate(evals, domain, blowup_factor)
lde = CFFT.extrapolate(evals, blowup_factor)
domain_lde = standard_position_cosets[log_2(len(evals) * blowup_factor)]

reduced = deep_quotient_reduce(lde, domain_lde, alpha, zeta, p_at_zeta)
Expand All @@ -113,8 +136,8 @@ def test_deep_quotient_reduce(self):
assert c == F31(0), f'coeffs[{i}] != 0, coeffs: {coeffs}'

def test_extract_lambda(self):
evals = [F31(randint(0, 31)) for _ in range(4)]
lde = CFFT.extrapolate(evals, standard_position_cosets[2], 2)
evals = [F31(randint(0, 30)) for _ in range(4)]
lde = CFFT.extrapolate(evals, 2)
domain = standard_position_cosets[3]

zeta = g_30 ** 3
Expand All @@ -139,7 +162,7 @@ def test_extract_lambda(self):
assert c == F31(0), f'coeffs[{i}] != 0, coeffs: {q_interpolated}'

def test_circle_pcs(self):
evals = [F31(randint(0, 31)) for _ in range(4)]
evals = [F31(randint(0, 30)) for _ in range(4)]
domain = CirclePCS.natural_domain_for_degree(len(evals))
log_blowup = 1

Expand Down

0 comments on commit 99b182d

Please sign in to comment.