From bbd2fe87bc9cce76c4696928bccdbc65fcc98223 Mon Sep 17 00:00:00 2001 From: Neil Date: Wed, 23 Oct 2024 17:14:15 +0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=A6=84=20refactor(fri):=20acording=20to?= =?UTF-8?q?=20boojum?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Modified implementation of FRI according to boojum. Actually, value at z is not sent to verifier before generating random number in the last version, it's unsecure. --- src/fri.py | 33 +++++++++++++++++---------------- tests/test_fri.py | 12 ++++++------ 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/fri.py b/src/fri.py index 7acdbdc..23b47fe 100644 --- a/src/fri.py +++ b/src/fri.py @@ -1,34 +1,36 @@ from merkle import MerkleTree, verify_decommitment from merlin.merlin_transcript import MerlinTranscript -from utils import from_bytes, log_2, is_power_of_two, next_power_of_two +from utils import from_bytes, log_2, is_power_of_two from unipolynomial import UniPolynomial class FRI: security_level = 128 @classmethod - def open(cls, evals, rate, point, gen, domain, debug=False): + def prove(cls, evals, rate, point, gen, domain, debug=False): if debug: print("evals:", evals) N = len(evals) - degree_bound = next_power_of_two(N - 1) if not is_power_of_two(N - 1) else N - 1 + assert is_power_of_two(N) + degree_bound = N if debug: print("degree_bound:", degree_bound) coeffs = UniPolynomial.compute_coeffs_from_evals_fast(evals, domain[:N]) if debug: print("coeffs:", coeffs) - val = UniPolynomial.uni_eval_from_evals(evals, point, domain[:N]) - if debug: print("val:", val) - assert len(domain) == N * rate, f"domain: {domain}, N: {N}, rate: {rate}" code = cls.rs_encode_single(coeffs, domain, rate) if debug: print("code:", code) assert len(code) == N * rate, f"code: {code}, degree_bound: {degree_bound}, rate: {rate}" + val = UniPolynomial.uni_eval_from_evals(evals, point, domain[:N]) + if debug: print("val:", val) + assert len(domain) == N * rate, f"domain: {domain}, N: {N}, rate: {rate}" code_tree = MerkleTree(code) - transcript = MerlinTranscript(b"FRI.open") + transcript = MerlinTranscript(b"FRI") transcript.append_message(b"code", code_tree.root.encode('ascii')) quotient = [(code[i] - val) / (domain[i] - point) for i in range(len(code))] quotient_tree = MerkleTree(quotient) transcript.append_message(b"quotient", quotient_tree.root.encode('ascii')) + transcript.append_message(b"value at z", str(val).encode('ascii')) z = from_bytes(transcript.challenge_bytes(b"z", 4)) % len(code) code_at_z_proof = code_tree.get_authentication_path(z) @@ -49,7 +51,7 @@ def open(cls, evals, rate, point, gen, domain, debug=False): if cls.security_level % log_2(rate) != 0: num_verifier_queries += 1 - low_degree_proof = cls.prove_low_degree(quotient, rate, degree_bound, gen, num_verifier_queries, debug) + low_degree_proof = cls.prove_low_degree(quotient, rate, degree_bound, gen, num_verifier_queries, transcript, debug) return { 'low_degree_proof': low_degree_proof, @@ -68,13 +70,14 @@ def verify(cls, degree_bound, evals_size, rate, proof, point, value, domain, gen assert degree_bound >= proof['degree_bound'] degree_bound = proof['degree_bound'] - transcript = MerlinTranscript(b"FRI.open") + transcript = MerlinTranscript(b"FRI") code_commitment = proof['code_commitment'] quotient_commitment = proof['quotient_commitment'] transcript.append_message(b"code", code_commitment.encode('ascii')) transcript.append_message(b"quotient", quotient_commitment.encode('ascii')) + transcript.append_message(b"value at z", str(value).encode('ascii')) z = from_bytes(transcript.challenge_bytes(b"z", 4)) % (evals_size * rate) code_at_z_proof = proof['code_at_z_proof'] @@ -96,15 +99,14 @@ def verify(cls, degree_bound, evals_size, rate, proof, point, value, domain, gen if cls.security_level % log_2(rate) != 0: num_verifier_queries += 1 - cls.verify_low_degree(degree_bound, rate, proof['low_degree_proof'], gen, num_verifier_queries, debug) + cls.verify_low_degree(degree_bound, rate, proof['low_degree_proof'], gen, num_verifier_queries, transcript, debug) @staticmethod - def prove_low_degree(evals, rate, degree_bound, gen, num_verifier_queries, debug=False): + def prove_low_degree(evals, rate, degree_bound, gen, num_verifier_queries, transcript, debug=False): assert is_power_of_two(degree_bound) first_tree = MerkleTree(evals) evals_copy = evals - transcript = MerlinTranscript(b"FRI") transcript.append_message(b"first_oracle", first_tree.root.encode('ascii')) alpha = transcript.challenge_bytes(b"alpha", 4) @@ -171,12 +173,12 @@ def fold(evals, alpha, g, debug=False): @staticmethod - def verify_low_degree(degree_bound, rate, proof, gen, num_verifier_queries, debug=False): + def verify_low_degree(degree_bound, rate, proof, gen, num_verifier_queries, transcript, debug=False): log_degree_bound = log_2(degree_bound) log_evals = log_2(degree_bound * rate) T = [[(gen**(2 ** j)) ** i for i in range(2 ** (log_evals - j - 1))] for j in range(0, log_evals)] if debug: print("T:", T) - FRI.verify_queries(proof, log_degree_bound, degree_bound * rate, num_verifier_queries, T, debug) + FRI.verify_queries(proof, log_degree_bound, degree_bound * rate, num_verifier_queries, T, transcript, debug) @staticmethod def query_phase(transcript: MerlinTranscript, first_tree: MerkleTree, first_oracle, trees: list, oracles: list, num_vars, num_verifier_queries, debug=False): @@ -238,8 +240,7 @@ def query_phase(transcript: MerlinTranscript, first_tree: MerkleTree, first_orac return query_paths, merkle_paths @staticmethod - def verify_queries(proof, k, num_vars, num_verifier_queries, T, debug=False): - transcript = MerlinTranscript(b"FRI") + def verify_queries(proof, k, num_vars, num_verifier_queries, T, transcript, debug=False): transcript.append_message(b"first_oracle", bytes(proof['first_oracle'], 'ascii')) alpha = transcript.challenge_bytes(b"alpha", 4) alpha = from_bytes(alpha) diff --git a/tests/test_fri.py b/tests/test_fri.py index 8cacbd4..cf9af06 100644 --- a/tests/test_fri.py +++ b/tests/test_fri.py @@ -33,22 +33,22 @@ def test_low_degree(self): from sage.all import GF from field import magic from random import randint + from merlin.merlin_transcript import MerlinTranscript Fp = magic(GF(193)) assert Fp.primitive_element() ** 192 == 1 degree_bound = 8 - blow_up_factor = 4 + blow_up_factor = 2 num_verifier_queries = 8 - assert is_power_of_two(degree_bound) evals = FRI.rs_encode_single([randint(0, 193) for _ in range(degree_bound)], [Fp.primitive_element() ** (i * 192 // (degree_bound * 2 ** blow_up_factor)) for i in range(degree_bound * 2 ** blow_up_factor)], 2 ** blow_up_factor) - proof = FRI.prove_low_degree(evals, 2 ** blow_up_factor, degree_bound, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False) - FRI.verify_low_degree(degree_bound, 2 ** blow_up_factor, proof, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False) + proof = FRI.prove_low_degree(evals, 2 ** blow_up_factor, degree_bound, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, MerlinTranscript(b'test'), debug=False) + FRI.verify_low_degree(degree_bound, 2 ** blow_up_factor, proof, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, MerlinTranscript(b'test'), debug=False) - def test_open(self): + def test_prove(self): from sage.all import GF from field import magic from random import randint @@ -63,7 +63,7 @@ def test_open(self): point = coset ** 0 * Fp.primitive_element() evals = [randint(0, 193) for i in range(evals_size)] value = UniPolynomial.uni_eval_from_evals(evals, point, [coset ** i for i in range(len(evals))]) - proof = FRI.open(evals, rate, point, coset, [coset ** i for i in range(evals_size * rate)], debug=False) + proof = FRI.prove(evals, rate, point, coset, [coset ** i for i in range(evals_size * rate)], debug=False) FRI.verify(proof['degree_bound'], evals_size, rate, proof, point, value, [coset ** i for i in range(evals_size * rate)], coset, debug=False)