diff --git a/src/fri.py b/src/fri.py index 3de4afa..7acdbdc 100644 --- a/src/fri.py +++ b/src/fri.py @@ -1,10 +1,105 @@ from merkle import MerkleTree, verify_decommitment from merlin.merlin_transcript import MerlinTranscript -from utils import from_bytes, log_2, is_power_of_two +from utils import from_bytes, log_2, is_power_of_two, next_power_of_two +from unipolynomial import UniPolynomial class FRI: + security_level = 128 + + @classmethod + def open(cls, evals, rate, point, gen, domain, debug=False): + if debug: print("evals:", evals) + N = len(evals) + degree_bound = next_power_of_two(N - 1) if not is_power_of_two(N - 1) else N - 1 + if debug: print("degree_bound:", degree_bound) + coeffs = UniPolynomial.compute_coeffs_from_evals_fast(evals, domain[:N]) + if debug: print("coeffs:", coeffs) + val = UniPolynomial.uni_eval_from_evals(evals, point, domain[:N]) + if debug: print("val:", val) + assert len(domain) == N * rate, f"domain: {domain}, N: {N}, rate: {rate}" + code = cls.rs_encode_single(coeffs, domain, rate) + if debug: print("code:", code) + assert len(code) == N * rate, f"code: {code}, degree_bound: {degree_bound}, rate: {rate}" + + code_tree = MerkleTree(code) + transcript = MerlinTranscript(b"FRI.open") + transcript.append_message(b"code", code_tree.root.encode('ascii')) + + quotient = [(code[i] - val) / (domain[i] - point) for i in range(len(code))] + + quotient_tree = MerkleTree(quotient) + transcript.append_message(b"quotient", quotient_tree.root.encode('ascii')) + + z = from_bytes(transcript.challenge_bytes(b"z", 4)) % len(code) + code_at_z_proof = code_tree.get_authentication_path(z) + quotient_at_z_proof = quotient_tree.get_authentication_path(z) + + if debug: + print('z:', z) + print('code[z]:', code[z]) + print('quotient[z]:', quotient[z]) + print('domain[z]:', domain[z]) + print('point:', point) + print('value:', val) + assert code[z] - val == quotient[z] * (domain[z] - point), \ + "failed to generate quotient, code: {}, quotient: {}, val: {}, z: {}, point: {}"\ + .format(code, quotient, val, z, point) + + num_verifier_queries = cls.security_level // log_2(rate) + if cls.security_level % log_2(rate) != 0: + num_verifier_queries += 1 + + low_degree_proof = cls.prove_low_degree(quotient, rate, degree_bound, gen, num_verifier_queries, debug) + + return { + 'low_degree_proof': low_degree_proof, + 'code_commitment': code_tree.root, + 'quotient_commitment': quotient_tree.root, + 'code_at_z_proof': code_at_z_proof, + 'quotient_at_z_proof': quotient_at_z_proof, + 'code_at_z': code[z], + 'quotient_at_z': quotient[z], + 'degree_bound': degree_bound, + } + + @classmethod + def verify(cls, degree_bound, evals_size, rate, proof, point, value, domain, gen, debug=False): + + assert degree_bound >= proof['degree_bound'] + degree_bound = proof['degree_bound'] + + transcript = MerlinTranscript(b"FRI.open") + + code_commitment = proof['code_commitment'] + quotient_commitment = proof['quotient_commitment'] + + transcript.append_message(b"code", code_commitment.encode('ascii')) + transcript.append_message(b"quotient", quotient_commitment.encode('ascii')) + + z = from_bytes(transcript.challenge_bytes(b"z", 4)) % (evals_size * rate) + code_at_z_proof = proof['code_at_z_proof'] + quotient_at_z_proof = proof['quotient_at_z_proof'] + + if debug: + print('z: ', z) + print('code_at_z: ', proof['code_at_z']) + print('quotient_at_z: ', proof['quotient_at_z']) + print('point:', point) + print('value:', value) + + assert verify_decommitment(z, proof['code_at_z'], code_at_z_proof, code_commitment), f"failed to check decommitment at code_at_z, z: {z}, code_at_z: {proof['code_at_z']}, code_commitment: {code_commitment}" + assert verify_decommitment(z, proof['quotient_at_z'], quotient_at_z_proof, quotient_commitment), f"failed to check decommitment at quotient_at_z, z: {z}, quotient_at_z: {proof['quotient_at_z']}, quotient_commitment: {quotient_commitment}" + + assert proof['code_at_z'] - value == proof['quotient_at_z'] * (domain[z] - point) + + num_verifier_queries = cls.security_level // log_2(rate) + if cls.security_level % log_2(rate) != 0: + num_verifier_queries += 1 + + cls.verify_low_degree(degree_bound, rate, proof['low_degree_proof'], gen, num_verifier_queries, debug) + @staticmethod - def prove_low_degree(evals, degree_bound, coset, num_verifier_queries, debug=False): + def prove_low_degree(evals, rate, degree_bound, gen, num_verifier_queries, debug=False): assert is_power_of_two(degree_bound) first_tree = MerkleTree(evals) @@ -20,8 +115,8 @@ def prove_low_degree(evals, degree_bound, coset, num_verifier_queries, debug=Fal for _ in range(log_2(degree_bound)): if debug: print("evals:", evals) if debug: print("alpha:", alpha) - if debug: print("coset:", coset) - evals = FRI.fold(evals, alpha, coset) + if debug: print("generator:", gen) + evals = FRI.fold(evals, alpha, gen) tree = MerkleTree(evals) trees.append(tree) tree_evals.append(evals) @@ -30,16 +125,24 @@ def prove_low_degree(evals, degree_bound, coset, num_verifier_queries, debug=Fal alpha = transcript.challenge_bytes(b"alpha", 4) alpha = from_bytes(alpha) - coset *= coset + gen *= gen + + if debug: + assert len(evals) == rate, f"evals: {evals}, rate: {rate}" + for i in range(len(evals)): + if i != 0: + assert evals[i] == evals[0], f"evals: {evals}" # query phase - query_paths, merkle_paths = FRI.query_phase(transcript, first_tree, evals_copy, trees, tree_evals, len(evals_copy), num_verifier_queries, debug) + assert len(evals_copy) == degree_bound * rate, f"evals_copy: {evals_copy}, degree_bound: {degree_bound}, rate: {rate}" + query_paths, merkle_paths = FRI.query_phase(transcript, first_tree, evals_copy, trees, tree_evals, degree_bound * rate, num_verifier_queries, debug) return { 'query_paths': query_paths, 'merkle_paths': merkle_paths, 'first_oracle': first_tree.root, - 'intermediate_oracles': [tree.root for tree in trees] + 'intermediate_oracles': [tree.root for tree in trees], + 'degree_bound': degree_bound, } # f(x) = f0(x^2) + x * f1(x^2) @@ -50,21 +153,30 @@ def prove_low_degree(evals, degree_bound, coset, num_verifier_queries, debug=Fal # f0(x^2) = (f(x) + f(-x)) / 2 # f1(x^2) = (f(x) - f(-x)) / 2x @staticmethod - def fold(evals, alpha, coset): + def fold(evals, alpha, g, debug=False): assert len(evals) % 2 == 0 half = len(evals) // 2 - f0_evals = [(evals[i] + evals[half + i]) // 2 for i in range(half)] - f1_evals = [(evals[i] - evals[half + i]) // (2 * coset) for i in range(half)] + f0_evals = [(evals[i] + evals[half + i]) / 2 for i in range(half)] + f1_evals = [(evals[i] - evals[half + i]) / (2 * g ** i) for i in range(half)] + + if debug: + x = g ** 5 + f_x = UniPolynomial.uni_eval_from_evals(evals, x, [g ** i for i in range(len(evals))]) + f0_x = UniPolynomial.uni_eval_from_evals(f0_evals, x ** 2, [(g ** 2) ** i for i in range(len(f0_evals))]) + f1_x = UniPolynomial.uni_eval_from_evals(f1_evals, x ** 2, [(g ** 2) ** i for i in range(len(f1_evals))]) + assert f_x == f0_x + x * f1_x, f"failed to fold, f_x: {f_x}, f0_x: {f0_x}, f1_x: {f1_x}, alpha: {alpha}" + return [x + alpha * y for x, y in zip(f0_evals, f1_evals)] @staticmethod - def verify_low_degree(degree_bound, proof, coset, num_verifier_queries, debug=False): + def verify_low_degree(degree_bound, rate, proof, gen, num_verifier_queries, debug=False): log_degree_bound = log_2(degree_bound) - log_evals = log_2(len(evals)) - T = [coset**(2 ** j) for j in range(0, log_evals)] - FRI.verify_queries(proof, log_degree_bound, len(evals), num_verifier_queries, T, debug) + log_evals = log_2(degree_bound * rate) + T = [[(gen**(2 ** j)) ** i for i in range(2 ** (log_evals - j - 1))] for j in range(0, log_evals)] + if debug: print("T:", T) + FRI.verify_queries(proof, log_degree_bound, degree_bound * rate, num_verifier_queries, T, debug) @staticmethod def query_phase(transcript: MerlinTranscript, first_tree: MerkleTree, first_oracle, trees: list, oracles: list, num_vars, num_verifier_queries, debug=False): @@ -155,16 +267,17 @@ def verify_queries(proof, k, num_vars, num_verifier_queries, T, debug=False): if debug: print("x1:", x1) if i != len(mps) - 1: - coset = T[i] - if debug: print("coset:", coset) + table = T[i] + if debug: print("table:", table) f_code_folded = cur_path[i + 1][0 if x0 < num_vars_copy / 4 else 1] alpha = fold_challenges[i] + if debug: assert x0 < len(table), f"x0: {x0}, table: {table}" if debug: print("f_code_folded:", f_code_folded) - if debug: print("expected:", ((code_left + code_right)/2 + alpha * (code_left - code_right)/(2*coset))) + if debug: print("expected:", ((code_left + code_right)/2 + alpha * (code_left - code_right)/(2*table[x0]))) if debug: print("code_left:", code_left) if debug: print("code_right:", code_right) if debug: print("alpha:", alpha) - assert f_code_folded == ((code_left + code_right)/2 + alpha * (code_left - code_right)/(2*coset)), f"failed to check fri, i: {i}, x0: {x0}, x1: {x1}, code_left: {code_left}, code_right: {code_right}, alpha: {alpha}, coset: {coset}" + assert f_code_folded == ((code_left + code_right)/2 + alpha * (code_left - code_right)/(2*table[x0])), f"failed to check fri, i: {i}, x0: {x0}, x1: {x1}, code_left: {code_left}, code_right: {code_right}, alpha: {alpha}, generator: {table}" if i == 0: assert verify_decommitment(x0, code_left, mp, proof['first_oracle']), "failed to check decommitment at first level" @@ -174,31 +287,12 @@ def verify_queries(proof, k, num_vars, num_verifier_queries, T, debug=False): num_vars_copy >>= 1 q = x0 + @staticmethod + def rs_encode_single(m, alpha, c): + k0 = len(m) + code = [None] * (k0 * c) + for i in range(k0 * c): + # Compute f_m(alpha[i]) + code[i] = sum(m[j] * (alpha[i] ** j) for j in range(k0)) + return code -def rs_encode_single(m, alpha, c): - k0 = len(m) - code = [None] * (k0 * c) - for i in range(k0 * c): - # Compute f_m(alpha[i]) - code[i] = sum(m[j] * (alpha[i] ** j) for j in range(k0)) - return code - - -if __name__ == "__main__": - from sage.all import * - from field import magic - from random import randint - - Fp = magic(GF(193)) - - assert Fp.primitive_element() ** 192 == 1 - - degree_bound = 8 - blow_up_factor = 4 - num_verifier_queries = 8 - - assert is_power_of_two(degree_bound) - - evals = rs_encode_single([randint(0, 193) for _ in range(degree_bound)], [Fp.primitive_element() ** (i * 192 // (degree_bound * 2 ** blow_up_factor)) for i in range(degree_bound * 2 ** blow_up_factor)], 2 ** blow_up_factor) - proof = FRI.prove_low_degree(evals, degree_bound, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False) - FRI.verify_low_degree(degree_bound, proof, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False) diff --git a/tests/test_fri.py b/tests/test_fri.py new file mode 100644 index 0000000..8cacbd4 --- /dev/null +++ b/tests/test_fri.py @@ -0,0 +1,71 @@ +from unittest import TestCase, main +import sys + +sys.path.append('../src') +sys.path.append('src') + +from fri import FRI +from utils import is_power_of_two +from unipolynomial import UniPolynomial +class TestFRI(TestCase): + def setUp(self): + # Set up a scalar field for testing (e.g., integers modulo a prime) + prime = 193 # A small prime for testing + UniPolynomial.set_scalar(int, lambda x: x % prime) + + def test_fold(self): + from sage.all import GF + from field import magic + + Fp = magic(GF(193)) + + evals = FRI.rs_encode_single([2, 3, 4, 5], [Fp.primitive_element() ** (i * 192 // 16) for i in range(16)], 4) + coset = Fp.primitive_element() ** (192 // len(evals)) + alpha = Fp(7) + + evals = FRI.fold(evals, alpha, coset, debug=True) + coset = coset ** 2 + evals = FRI.fold(evals, alpha, coset, debug=True) + + assert evals[0] == evals[1] == evals[2] == evals[3] + + def test_low_degree(self): + from sage.all import GF + from field import magic + from random import randint + + Fp = magic(GF(193)) + + assert Fp.primitive_element() ** 192 == 1 + + degree_bound = 8 + blow_up_factor = 4 + num_verifier_queries = 8 + + assert is_power_of_two(degree_bound) + + evals = FRI.rs_encode_single([randint(0, 193) for _ in range(degree_bound)], [Fp.primitive_element() ** (i * 192 // (degree_bound * 2 ** blow_up_factor)) for i in range(degree_bound * 2 ** blow_up_factor)], 2 ** blow_up_factor) + proof = FRI.prove_low_degree(evals, 2 ** blow_up_factor, degree_bound, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False) + FRI.verify_low_degree(degree_bound, 2 ** blow_up_factor, proof, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False) + + def test_open(self): + from sage.all import GF + from field import magic + from random import randint + + Fp = magic(GF(193)) + + assert Fp.primitive_element() ** 192 == 1 + + rate = 4 + evals_size = 4 + coset = Fp.primitive_element() ** (192 // (evals_size * rate)) + point = coset ** 0 * Fp.primitive_element() + evals = [randint(0, 193) for i in range(evals_size)] + value = UniPolynomial.uni_eval_from_evals(evals, point, [coset ** i for i in range(len(evals))]) + proof = FRI.open(evals, rate, point, coset, [coset ** i for i in range(evals_size * rate)], debug=False) + FRI.verify(proof['degree_bound'], evals_size, rate, proof, point, value, [coset ** i for i in range(evals_size * rate)], coset, debug=False) + + +if __name__ == '__main__': + main()