Skip to content

Commit

Permalink
feat(fri): refector functions
Browse files Browse the repository at this point in the history
split original 'prove' function into 'commit' and 'prove'
  • Loading branch information
ahy231 committed Nov 26, 2024
1 parent 99b182d commit d3baa83
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 23 deletions.
25 changes: 13 additions & 12 deletions src/fri.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class FRI:
security_level = 128

@classmethod
def prove(cls, evals, rate, point, gen, domain, debug=False):
def commit(cls, evals, rate, domain, debug=False):
if debug: print("evals:", evals)
N = len(evals)
assert is_power_of_two(N)
Expand All @@ -18,13 +18,14 @@ def prove(cls, evals, rate, point, gen, domain, debug=False):
code = cls.rs_encode_single(coeffs, domain, rate)
if debug: print("code:", code)
assert len(code) == N * rate, f"code: {code}, degree_bound: {degree_bound}, rate: {rate}"
val = UniPolynomial.uni_eval_from_evals(evals, point, domain[:N])
if debug: print("val:", val)
assert len(domain) == N * rate, f"domain: {domain}, N: {N}, rate: {rate}"

code_tree = MerkleTree(code)
transcript = MerlinTranscript(b"FRI")
transcript.append_message(b"code", code_tree.root.encode('ascii'))
return MerkleTree(code), code

@classmethod
def prove(cls, code, code_tree, val, point, domain, rate, degree_bound, gen, transcript, debug=False):
if debug: print("val:", val)
assert len(domain) == degree_bound * rate, f"domain: {domain}, degree_bound: {degree_bound}, rate: {rate}"
assert isinstance(transcript, MerlinTranscript), f"transcript: {transcript}"

quotient = [(code[i] - val) / (domain[i] - point) for i in range(len(code))]

Expand Down Expand Up @@ -65,12 +66,12 @@ def prove(cls, evals, rate, point, gen, domain, debug=False):
}

@classmethod
def verify(cls, degree_bound, evals_size, rate, proof, point, value, domain, gen, debug=False):
def verify(cls, degree_bound, rate, proof, point, value, domain, gen, transcript, debug=False):

assert degree_bound >= proof['degree_bound']
degree_bound = proof['degree_bound']

transcript = MerlinTranscript(b"FRI")
assert isinstance(transcript, MerlinTranscript), f"transcript: {transcript}"

code_commitment = proof['code_commitment']
quotient_commitment = proof['quotient_commitment']
Expand All @@ -79,7 +80,7 @@ def verify(cls, degree_bound, evals_size, rate, proof, point, value, domain, gen
transcript.append_message(b"quotient", quotient_commitment.encode('ascii'))
transcript.append_message(b"value at z", str(value).encode('ascii'))

z = from_bytes(transcript.challenge_bytes(b"z", 4)) % (evals_size * rate)
z = from_bytes(transcript.challenge_bytes(b"z", 4)) % (degree_bound * rate)
code_at_z_proof = proof['code_at_z_proof']
quotient_at_z_proof = proof['quotient_at_z_proof']

Expand Down Expand Up @@ -268,9 +269,9 @@ def verify_queries(proof, k, num_vars, num_verifier_queries, T, transcript, debu
if debug: print("x0:", x0)
if debug: print("x1:", x1)

table = T[i]
if debug: print("table:", table)
if i != len(mps) - 1:
table = T[i]
if debug: print("table:", table)
f_code_folded = cur_path[i + 1][0 if x0 < num_vars_copy / 4 else 1]
alpha = fold_challenges[i]
if debug: assert x0 < len(table), f"x0: {x0}, table: {table}"
Expand Down
10 changes: 5 additions & 5 deletions src/unipolynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,27 +644,27 @@ def interpolate(cls, evals, domain):

# barycentric interpolation
@classmethod
def barycentric_weights(cls, D):
def barycentric_weights(cls, D, one=1):
n = len(D)
weights = [1] * n
for i in range(n):
# weights[i] = product([(D[i] - D[j]) if i !=j else Fp(1) for j in range(n)])
for j in range(n):
if i==j:
weights[i] *= 1
weights[i] *= one
continue
weights[i] *= (D[i] - D[j])
weights[i] = 1/weights[i]
weights[i] = one / weights[i]
return weights

@classmethod
def uni_eval_from_evals(cls, evals, z, D):
def uni_eval_from_evals(cls, evals, z, D, one=1):
n = len(evals)
if n != len(D):
raise ValueError("Domain size should be equal to the length of evaluations")
if z in D:
return evals[D.index(z)]
weights = cls.barycentric_weights(D)
weights = cls.barycentric_weights(D, one)
# print("weights={}".format(weights))
e_vec = [weights[i] / (z - D[i]) for i in range(n)]
numerator = sum([e_vec[i] * evals[i] for i in range(n)])
Expand Down
17 changes: 11 additions & 6 deletions tests/test_fri.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def test_fold(self):
coset = Fp.primitive_element() ** (192 // len(evals))
alpha = Fp(7)

evals = FRI.fold(evals, alpha, coset, debug=True)
evals = FRI.fold(evals, alpha, coset, debug=False)
coset = coset ** 2
evals = FRI.fold(evals, alpha, coset, debug=True)
evals = FRI.fold(evals, alpha, coset, debug=False)

assert evals[0] == evals[1] == evals[2] == evals[3]

Expand All @@ -52,6 +52,7 @@ def test_prove(self):
from sage.all import GF
from field import magic
from random import randint
from merlin.merlin_transcript import MerlinTranscript

Fp = magic(GF(193))

Expand All @@ -60,11 +61,15 @@ def test_prove(self):
rate = 4
evals_size = 4
coset = Fp.primitive_element() ** (192 // (evals_size * rate))
point = coset ** randint(evals_size * rate, 192) * Fp.primitive_element()
evals = [randint(0, 193) for i in range(evals_size)]
point = Fp.primitive_element()
evals = [i for i in range(evals_size)]
value = UniPolynomial.uni_eval_from_evals(evals, point, [coset ** i for i in range(len(evals))])
proof = FRI.prove(evals, rate, point, coset, [coset ** i for i in range(evals_size * rate)], debug=False)
FRI.verify(proof['degree_bound'], evals_size, rate, proof, point, value, [coset ** i for i in range(evals_size * rate)], coset, debug=False)
domain = [coset ** i for i in range(evals_size * rate)]
code_tree, code = FRI.commit(evals, rate, domain, debug=False)
transcript = MerlinTranscript(b'test')
transcript.append_message(b"code", code_tree.root.encode('ascii'))
proof = FRI.prove(code, code_tree, value, point, domain, rate, evals_size, coset, transcript, debug=False)
FRI.verify(evals_size, rate, proof, point, value, domain, coset, MerlinTranscript(b'test'), debug=False)


if __name__ == '__main__':
Expand Down

0 comments on commit d3baa83

Please sign in to comment.