Skip to content

Commit

Permalink
fix(circle.py): debug
Browse files Browse the repository at this point in the history
  • Loading branch information
ahy231 committed Nov 13, 2024
1 parent 2304181 commit 9dda07a
Showing 1 changed file with 107 additions and 33 deletions.
140 changes: 107 additions & 33 deletions src/circle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sage.all import *
import sys

sys.path.append("../")
sys.path.append("src")
sys.path.append("../src")

from utils import log_2
Expand All @@ -17,7 +17,25 @@
C31.inject_variables()
I, = C31.gens()

g = 10 + 5 * I
# g = 10 + 5 * I # not a generator
# g = 29 + 11 * I # neither

def test_generator(x, y):
g = x + y * I
g_30 = g**30
for i in range(1, 32):
if g_30**i == 1:
return False
return True

g = 0

for x in range(31):
for y in range(31):
if test_generator(x, y):
g = x + y * I
break

g_30 = g**30

assert g_30**32 == 1
Expand All @@ -29,6 +47,15 @@ def sq(D):
rs += [t**2]
return rs

G5 = [g_30**k for k in range(32)]
G = [G5]
tmp = G[-1]
for i in range(5):
tmp = sq(tmp)
G = [tmp] + G

standard_position_cosets = [[G[i + 1][1] * p for p in G[i]] for i in range(4)]

def group_inv(g1):
x1, y1 = g1
return x1 - y1 * I
Expand Down Expand Up @@ -87,7 +114,7 @@ def compute_lagrange_den_batched(points, at, log_n, debug=False):
numer.append(x + 1)
if debug:
print('y:', y, 'pt:', pt, 's_p_at_p:', s_p_at_p(pt, log_n))
denom.append(y * s_p_at_p(pt, log_n, debug))
denom.append(y * s_p_at_p(pt, log_n, debug=False))

inv_d = batch_multiplicative_inverse(denom)

Expand All @@ -97,20 +124,22 @@ def evaluate_at_point(evals, domain, point, debug=False):
assert len(evals) == len(domain), "len(evals) != len(domain), {} != {}, evals={}, domain={}".format(len(evals), len(domain), evals, domain)
x, _ = point
log_n = log_2(len(evals))
shift = g ** (5 - log_n)
shift = g_30 ** (5 - log_n)
lagrange_num = zeroifier(x, shift, log_n)
lagrange_den = compute_lagrange_den_batched(domain, point, log_n, debug)

return sum([lagrange_den[i] * evals[i] for i in range(len(evals))]) * lagrange_num

def deep_quotient_vanishing_part(x, zeta, alpha_pow_width):
def deep_quotient_vanishing_part(x, zeta, alpha_pow_width, debug=False):
v_p = lambda p, at: (1 - (p - at)[0], -(p - at)[1])
re_v_zeta, im_v_zeta = v_p(x, zeta)
# if debug: print('re_v_zeta:', re_v_zeta, 'im_v_zeta:', im_v_zeta)
return (re_v_zeta - alpha_pow_width * im_v_zeta, re_v_zeta ** 2 + im_v_zeta ** 2)

def deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta):
vp_nums, vp_demons = zip(*[(deep_quotient_vanishing_part(x, zeta, alpha)) for x in domain])
def deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta, debug=False):
vp_nums, vp_demons = zip(*[(deep_quotient_vanishing_part(x, zeta, alpha, debug)) for x in domain])
vp_denom_invs = batch_multiplicative_inverse(vp_demons)
if debug: print('vp_nums:', vp_nums, 'vp_denom_invs:', vp_denom_invs, 'p_at_zeta:', p_at_zeta, 'evals:', evals)

return [vp_nums[i] * vp_denom_invs[i] * (-p_at_zeta + evals[i]) for i in range(len(evals))]

Expand All @@ -128,7 +157,7 @@ def extract_lambda(lde, log_blowup, debug=False):
if debug: print('CirclePCS.domains[log_lde_size][:1 << log_blowup]:', CirclePCS.domains[log_lde_size][:1 << log_blowup])
v_d_init = [v_n(p[0], log_lde_size - log_blowup) for p in CirclePCS.domains[log_lde_size][:1 << log_blowup]]

v_d = v_d_init + v_d_init[:-1]
v_d = v_d_init + v_d_init[::-1]
while (len(v_d) < len(lde)):
v_d += v_d

Expand All @@ -145,6 +174,47 @@ def extract_lambda(lde, log_blowup, debug=False):

return new_lde, lambda_

def twin_cosets(n, size):
k = log_2(size * n)
log_size = log_2(size)
G_size_over_2 = G[log_size - 1]

shifts = [standard_position_cosets[k][i] for i in range(size * n // 4)]
shifts_inv = [group_inv(shifts[i]) for i in range(size * n // 4)]
coset_1 = [[G_size_over_2[j] * shifts[i] for j in range(size // 2)] for i in range(n)]
coset_2 = [[G_size_over_2[(j + 1) % (size // 2)] * shifts_inv[i] for j in range(size // 2)] for i in range(n)]
res = []
for i in range(n):
c1 = coset_1[i]
c2 = coset_2[i]
tmp = zip(c1, c2)
res += [[x for y in tmp for x in list(y)]]
return res

def pop(v):
assert len(v) > 0, "v is empty"
return v[0], v[1:]

def combine(cosets):
cosets = cosets[:]
n = len(cosets)
res = []
while len(cosets[0]) > 0:
for i in range(n):
t, cosets[i] = pop(cosets[i])
res += [t]
for i in range(n):
t, cosets[n - 1 - i] = pop(cosets[n - 1 - i])
res += [t]
return res

# test twin_cosets
tcs = twin_cosets(2, 4)
for tc in tcs:
for t in tc:
assert t in standard_position_cosets[log_2(8)]
assert combine(tcs) == standard_position_cosets[log_2(8)], f'combine error, {combine(tcs)}, {standard_position_cosets[log_2(8)]}'

class CFFT:
@classmethod
def _ifft_first_step(cls, f):
Expand Down Expand Up @@ -174,8 +244,8 @@ def _ifft_normal_step(cls, f):

for x in f:
assert x != 0, "f should be on coset"
f0[pi(x)] = (f[x] + f[-x]) / C31(2)
f1[pi(x)] = (f[x] - f[-x]) / (C31(2) * x)
f0[pi(x)] = (f[x] + f[-x]) / F31(2)
f1[pi(x)] = (f[x] - f[-x]) / (F31(2) * x)

# Check that f is divided into 2 parts correctly
assert f[x] == f0[pi(x)] + x * f1[pi(x)]
Expand All @@ -201,11 +271,7 @@ def fft_first_step(cls, f, D):
f1 = f[len_f//2:]

# halve the domain by simply removing the y coordinate
D_new = []
for t in D:
x, _ = t
if x not in D_new:
D_new.append(x)
D_new = [p[0] for p in D[:len(D)//2]]

# Check that the new domain is exactly half size of the old domain
assert len(D_new) * 2 == len(D), "len(D_new) * 2 != len(D), {} * 2 != {}, D_new={}, D={}".format(len(D_new), len(D), D_new, D)
Expand Down Expand Up @@ -282,11 +348,15 @@ def poly_2_vec(cls, poly):
return [poly[t] for t in poly]

@classmethod
def extrapolate(cls, evals, blowup_factor):
new_domain = CirclePCS.natural_domain_for_degree(len(evals) * blowup_factor)
coeffs = CFFT.ifft(evals)
coeffs += [0] * (len(evals) * (blowup_factor - 1))
return CFFT.fft(coeffs, new_domain)
def extrapolate(cls, evals, domain, blowup_factor):
evals = cls.vec_2_poly(evals, domain)
coeffs = cls.ifft(evals)
cosets = twin_cosets(blowup_factor, len(evals))
res = []
for coset in cosets:
res += [cls.fft(coeffs, coset)]
res = [cls.poly_2_vec(x) for x in res]
return combine(res)

class FRI:
@classmethod
Expand All @@ -307,7 +377,7 @@ def fold_y(cls, evals, domain, beta, debug=False):
if debug: print('fold y')
if debug: print(f"f0 = (({left[i]}) + ({right[i]}))/2 = {f0}")
if debug: print(f"f1 = (({left[i]}) - ({right[i]}))/(2 * {y}) = {f1}")
if debug: print(f"f0 + {beta} * f1 = ({f0}) + {beta} * ({f1}) = {f0 + group_mul(beta, f1)}")
if debug: print(f"f0 + ({beta}) * f1 = ({f0}) + ({beta}) * ({f1}) = {f0 + group_mul(beta, f1)}")

return evals

Expand All @@ -318,7 +388,7 @@ def fold_y_row(cls, y, beta, left, right, debug=False):
if debug: print('fold y row')
if debug: print(f"f0 = (({left}) + ({right}))/2 = {f0}")
if debug: print(f"f1 = (({left}) - ({right}))/(2 * {y}) = {f1}")
if debug: print(f"f0 + {beta} * f1 = ({f0}) + {beta} * ({f1}) = {f0 + group_mul(beta, f1)}")
if debug: print(f"f0 + ({beta}) * f1 = ({f0}) + ({beta}) * ({f1}) = {f0 + group_mul(beta, f1)}")
return f0 + group_mul(beta, f1)

# Inputs:
Expand Down Expand Up @@ -521,8 +591,8 @@ def verify(cls, proof, transcript, open_input, debug=False):
assert folded_eval == proof["final_poly"], "folded_eval != proof['final_poly'], {} != {}".format(folded_eval, proof["final_poly"])

class CirclePCS:
G5 = [g**k for k in range(32)]
G4_standard = [g * t for t in sq(G5)]
G5 = [g_30**k for k in range(32)]
G4_standard = [g_30 * t for t in sq(G5)]
G3_standard = sq(G4_standard)
G2_standard = sq(G3_standard)
G1_standard = sq(G2_standard)
Expand All @@ -543,24 +613,24 @@ def commit(cls, eval, domain, blowup_factor):
if log_n + log_2(blowup_factor) > 4:
raise ValueError("Eval too long")

eval_poly = CFFT.vec_2_poly(eval, domain)
lde = CFFT.extrapolate(eval_poly, blowup_factor)
lde = CFFT.poly_2_vec(lde)
lde = CFFT.extrapolate(eval, domain, blowup_factor)
return MerkleTree(lde), lde

@classmethod
def open(cls, evals, evals_commit, zeta, log_blowup, transcript, num_queries, debug=False):
if debug: print('evals:', evals)
assert isinstance(transcript, MerlinTranscript), "transcript should be a MerlinTranscript"
alpha = int.from_bytes(transcript.challenge_bytes(b"alpha", 4), "big")
if debug: print('alpha:', alpha)

# evaluate the polynomial at the point zeta
domain = cls.natural_domain_for_degree(len(evals))
p_at_zeta = evaluate_at_point(evals, domain, zeta, debug)
if debug: print('p_at_zeta:', p_at_zeta)

# deep quotient
reduced_opening = deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta)

reduced_opening = deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta, debug)
if debug: print('reduced_opening:', reduced_opening)
# extract lambda
first_layer, lambda_ = extract_lambda(reduced_opening, log_blowup, debug)
if debug: print('first_layer:', first_layer, ', lambda_:', lambda_)
Expand Down Expand Up @@ -645,7 +715,8 @@ def open_input(index, input_proof):

left = lambda_corrected if index_shifted < len(domain) // 2 else first_layer_sibling_value
right = first_layer_sibling_value if index_shifted < len(domain) // 2 else lambda_corrected
_, y = domain[index_shifted]
index_sibling = len(domain) - 1 - index_shifted
_, y = domain[min(index_shifted, index_sibling)]
fri_input = FRI.fold_y_row(y, bivariate_beta, left, right, debug)
fri_input = fri_input[0]
if debug: print('fri_input:', fri_input)
Expand All @@ -662,11 +733,14 @@ def open_input(index, input_proof):
if __name__ == "__main__":
from random import randint
rand_ext = lambda: randint(0, 31) + randint(0, 31) * I
evals = [rand_ext() for _ in range(4)]
# evals = [rand_ext() for _ in range(2)]
evals = [F31(1), F31(2), F31(3), F31(4)]
domain = CirclePCS.natural_domain_for_degree(len(evals))
log_blowup = 0
print('domain:', domain)
log_blowup = 1

commitment, lde = CirclePCS.commit(evals, domain, 1 << log_blowup)
print('lde:', lde)

transcript = MerlinTranscript(b'circle pcs')
transcript.append_message(b'commitment', bytes(str(commitment.root), 'ascii'))
Expand All @@ -679,4 +753,4 @@ def open_input(index, input_proof):

transcript = MerlinTranscript(b'circle pcs')
transcript.append_message(b'commitment', bytes(str(commitment.root), 'ascii'))
CirclePCS.verify(commitment.root, domain, log_blowup, point, evaluate_at_point(evals, domain, point), proof, transcript, True)
CirclePCS.verify(commitment.root, domain, log_blowup, point, evaluate_at_point(evals, CirclePCS.natural_domain_for_degree(len(evals)), point), proof, transcript, True)

0 comments on commit 9dda07a

Please sign in to comment.