Skip to content

Commit

Permalink
🦄 refactor(fri): acording to boojum
Browse files Browse the repository at this point in the history
Modified implementation of FRI according to boojum. Actually, value at z is not sent to verifier before generating random number in the last version, it's unsecure.
  • Loading branch information
ahy231 committed Oct 23, 2024
1 parent acb34ab commit bbd2fe8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
33 changes: 17 additions & 16 deletions src/fri.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,36 @@
from merkle import MerkleTree, verify_decommitment
from merlin.merlin_transcript import MerlinTranscript
from utils import from_bytes, log_2, is_power_of_two, next_power_of_two
from utils import from_bytes, log_2, is_power_of_two
from unipolynomial import UniPolynomial

class FRI:
security_level = 128

@classmethod
def open(cls, evals, rate, point, gen, domain, debug=False):
def prove(cls, evals, rate, point, gen, domain, debug=False):
if debug: print("evals:", evals)
N = len(evals)
degree_bound = next_power_of_two(N - 1) if not is_power_of_two(N - 1) else N - 1
assert is_power_of_two(N)
degree_bound = N
if debug: print("degree_bound:", degree_bound)
coeffs = UniPolynomial.compute_coeffs_from_evals_fast(evals, domain[:N])
if debug: print("coeffs:", coeffs)
val = UniPolynomial.uni_eval_from_evals(evals, point, domain[:N])
if debug: print("val:", val)
assert len(domain) == N * rate, f"domain: {domain}, N: {N}, rate: {rate}"
code = cls.rs_encode_single(coeffs, domain, rate)
if debug: print("code:", code)
assert len(code) == N * rate, f"code: {code}, degree_bound: {degree_bound}, rate: {rate}"
val = UniPolynomial.uni_eval_from_evals(evals, point, domain[:N])
if debug: print("val:", val)
assert len(domain) == N * rate, f"domain: {domain}, N: {N}, rate: {rate}"

code_tree = MerkleTree(code)
transcript = MerlinTranscript(b"FRI.open")
transcript = MerlinTranscript(b"FRI")
transcript.append_message(b"code", code_tree.root.encode('ascii'))

quotient = [(code[i] - val) / (domain[i] - point) for i in range(len(code))]

quotient_tree = MerkleTree(quotient)
transcript.append_message(b"quotient", quotient_tree.root.encode('ascii'))
transcript.append_message(b"value at z", str(val).encode('ascii'))

z = from_bytes(transcript.challenge_bytes(b"z", 4)) % len(code)
code_at_z_proof = code_tree.get_authentication_path(z)
Expand All @@ -49,7 +51,7 @@ def open(cls, evals, rate, point, gen, domain, debug=False):
if cls.security_level % log_2(rate) != 0:
num_verifier_queries += 1

low_degree_proof = cls.prove_low_degree(quotient, rate, degree_bound, gen, num_verifier_queries, debug)
low_degree_proof = cls.prove_low_degree(quotient, rate, degree_bound, gen, num_verifier_queries, transcript, debug)

return {
'low_degree_proof': low_degree_proof,
Expand All @@ -68,13 +70,14 @@ def verify(cls, degree_bound, evals_size, rate, proof, point, value, domain, gen
assert degree_bound >= proof['degree_bound']
degree_bound = proof['degree_bound']

transcript = MerlinTranscript(b"FRI.open")
transcript = MerlinTranscript(b"FRI")

code_commitment = proof['code_commitment']
quotient_commitment = proof['quotient_commitment']

transcript.append_message(b"code", code_commitment.encode('ascii'))
transcript.append_message(b"quotient", quotient_commitment.encode('ascii'))
transcript.append_message(b"value at z", str(value).encode('ascii'))

z = from_bytes(transcript.challenge_bytes(b"z", 4)) % (evals_size * rate)
code_at_z_proof = proof['code_at_z_proof']
Expand All @@ -96,15 +99,14 @@ def verify(cls, degree_bound, evals_size, rate, proof, point, value, domain, gen
if cls.security_level % log_2(rate) != 0:
num_verifier_queries += 1

cls.verify_low_degree(degree_bound, rate, proof['low_degree_proof'], gen, num_verifier_queries, debug)
cls.verify_low_degree(degree_bound, rate, proof['low_degree_proof'], gen, num_verifier_queries, transcript, debug)

@staticmethod
def prove_low_degree(evals, rate, degree_bound, gen, num_verifier_queries, debug=False):
def prove_low_degree(evals, rate, degree_bound, gen, num_verifier_queries, transcript, debug=False):
assert is_power_of_two(degree_bound)

first_tree = MerkleTree(evals)
evals_copy = evals
transcript = MerlinTranscript(b"FRI")
transcript.append_message(b"first_oracle", first_tree.root.encode('ascii'))

alpha = transcript.challenge_bytes(b"alpha", 4)
Expand Down Expand Up @@ -171,12 +173,12 @@ def fold(evals, alpha, g, debug=False):


@staticmethod
def verify_low_degree(degree_bound, rate, proof, gen, num_verifier_queries, debug=False):
def verify_low_degree(degree_bound, rate, proof, gen, num_verifier_queries, transcript, debug=False):
log_degree_bound = log_2(degree_bound)
log_evals = log_2(degree_bound * rate)
T = [[(gen**(2 ** j)) ** i for i in range(2 ** (log_evals - j - 1))] for j in range(0, log_evals)]
if debug: print("T:", T)
FRI.verify_queries(proof, log_degree_bound, degree_bound * rate, num_verifier_queries, T, debug)
FRI.verify_queries(proof, log_degree_bound, degree_bound * rate, num_verifier_queries, T, transcript, debug)

@staticmethod
def query_phase(transcript: MerlinTranscript, first_tree: MerkleTree, first_oracle, trees: list, oracles: list, num_vars, num_verifier_queries, debug=False):
Expand Down Expand Up @@ -238,8 +240,7 @@ def query_phase(transcript: MerlinTranscript, first_tree: MerkleTree, first_orac
return query_paths, merkle_paths

@staticmethod
def verify_queries(proof, k, num_vars, num_verifier_queries, T, debug=False):
transcript = MerlinTranscript(b"FRI")
def verify_queries(proof, k, num_vars, num_verifier_queries, T, transcript, debug=False):
transcript.append_message(b"first_oracle", bytes(proof['first_oracle'], 'ascii'))
alpha = transcript.challenge_bytes(b"alpha", 4)
alpha = from_bytes(alpha)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_fri.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@ def test_low_degree(self):
from sage.all import GF
from field import magic
from random import randint
from merlin.merlin_transcript import MerlinTranscript

Fp = magic(GF(193))

assert Fp.primitive_element() ** 192 == 1

degree_bound = 8
blow_up_factor = 4
blow_up_factor = 2
num_verifier_queries = 8

assert is_power_of_two(degree_bound)

evals = FRI.rs_encode_single([randint(0, 193) for _ in range(degree_bound)], [Fp.primitive_element() ** (i * 192 // (degree_bound * 2 ** blow_up_factor)) for i in range(degree_bound * 2 ** blow_up_factor)], 2 ** blow_up_factor)
proof = FRI.prove_low_degree(evals, 2 ** blow_up_factor, degree_bound, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False)
FRI.verify_low_degree(degree_bound, 2 ** blow_up_factor, proof, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False)
proof = FRI.prove_low_degree(evals, 2 ** blow_up_factor, degree_bound, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, MerlinTranscript(b'test'), debug=False)
FRI.verify_low_degree(degree_bound, 2 ** blow_up_factor, proof, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, MerlinTranscript(b'test'), debug=False)

def test_open(self):
def test_prove(self):
from sage.all import GF
from field import magic
from random import randint
Expand All @@ -63,7 +63,7 @@ def test_open(self):
point = coset ** 0 * Fp.primitive_element()
evals = [randint(0, 193) for i in range(evals_size)]
value = UniPolynomial.uni_eval_from_evals(evals, point, [coset ** i for i in range(len(evals))])
proof = FRI.open(evals, rate, point, coset, [coset ** i for i in range(evals_size * rate)], debug=False)
proof = FRI.prove(evals, rate, point, coset, [coset ** i for i in range(evals_size * rate)], debug=False)
FRI.verify(proof['degree_bound'], evals_size, rate, proof, point, value, [coset ** i for i in range(evals_size * rate)], coset, debug=False)


Expand Down

0 comments on commit bbd2fe8

Please sign in to comment.