diff --git a/src/fri.py b/src/fri.py index 7ca7666..4df8916 100644 --- a/src/fri.py +++ b/src/fri.py @@ -7,7 +7,7 @@ class FRI: security_level = 128 @classmethod - def prove(cls, evals, rate, point, gen, domain, debug=False): + def commit(cls, evals, rate, domain, debug=False): if debug: print("evals:", evals) N = len(evals) assert is_power_of_two(N) @@ -18,13 +18,14 @@ def prove(cls, evals, rate, point, gen, domain, debug=False): code = cls.rs_encode_single(coeffs, domain, rate) if debug: print("code:", code) assert len(code) == N * rate, f"code: {code}, degree_bound: {degree_bound}, rate: {rate}" - val = UniPolynomial.uni_eval_from_evals(evals, point, domain[:N]) - if debug: print("val:", val) - assert len(domain) == N * rate, f"domain: {domain}, N: {N}, rate: {rate}" - code_tree = MerkleTree(code) - transcript = MerlinTranscript(b"FRI") - transcript.append_message(b"code", code_tree.root.encode('ascii')) + return MerkleTree(code), code + + @classmethod + def prove(cls, code, code_tree, val, point, domain, rate, degree_bound, gen, transcript, debug=False): + if debug: print("val:", val) + assert len(domain) == degree_bound * rate, f"domain: {domain}, degree_bound: {degree_bound}, rate: {rate}" + assert isinstance(transcript, MerlinTranscript), f"transcript: {transcript}" quotient = [(code[i] - val) / (domain[i] - point) for i in range(len(code))] @@ -65,12 +66,12 @@ def prove(cls, evals, rate, point, gen, domain, debug=False): } @classmethod - def verify(cls, degree_bound, evals_size, rate, proof, point, value, domain, gen, debug=False): + def verify(cls, degree_bound, rate, proof, point, value, domain, gen, transcript, debug=False): assert degree_bound >= proof['degree_bound'] degree_bound = proof['degree_bound'] - transcript = MerlinTranscript(b"FRI") + assert isinstance(transcript, MerlinTranscript), f"transcript: {transcript}" code_commitment = proof['code_commitment'] quotient_commitment = proof['quotient_commitment'] @@ -79,7 +80,7 @@ def verify(cls, degree_bound, evals_size, rate, proof, point, value, domain, gen transcript.append_message(b"quotient", quotient_commitment.encode('ascii')) transcript.append_message(b"value at z", str(value).encode('ascii')) - z = from_bytes(transcript.challenge_bytes(b"z", 4)) % (evals_size * rate) + z = from_bytes(transcript.challenge_bytes(b"z", 4)) % (degree_bound * rate) code_at_z_proof = proof['code_at_z_proof'] quotient_at_z_proof = proof['quotient_at_z_proof'] @@ -268,9 +269,9 @@ def verify_queries(proof, k, num_vars, num_verifier_queries, T, transcript, debu if debug: print("x0:", x0) if debug: print("x1:", x1) + table = T[i] + if debug: print("table:", table) if i != len(mps) - 1: - table = T[i] - if debug: print("table:", table) f_code_folded = cur_path[i + 1][0 if x0 < num_vars_copy / 4 else 1] alpha = fold_challenges[i] if debug: assert x0 < len(table), f"x0: {x0}, table: {table}" diff --git a/src/unipolynomial.py b/src/unipolynomial.py index a2e6782..5b505a2 100644 --- a/src/unipolynomial.py +++ b/src/unipolynomial.py @@ -644,27 +644,27 @@ def interpolate(cls, evals, domain): # barycentric interpolation @classmethod - def barycentric_weights(cls, D): + def barycentric_weights(cls, D, one=1): n = len(D) weights = [1] * n for i in range(n): # weights[i] = product([(D[i] - D[j]) if i !=j else Fp(1) for j in range(n)]) for j in range(n): if i==j: - weights[i] *= 1 + weights[i] *= one continue weights[i] *= (D[i] - D[j]) - weights[i] = 1/weights[i] + weights[i] = one / weights[i] return weights @classmethod - def uni_eval_from_evals(cls, evals, z, D): + def uni_eval_from_evals(cls, evals, z, D, one=1): n = len(evals) if n != len(D): raise ValueError("Domain size should be equal to the length of evaluations") if z in D: return evals[D.index(z)] - weights = cls.barycentric_weights(D) + weights = cls.barycentric_weights(D, one) # print("weights={}".format(weights)) e_vec = [weights[i] / (z - D[i]) for i in range(n)] numerator = sum([e_vec[i] * evals[i] for i in range(n)]) diff --git a/tests/test_fri.py b/tests/test_fri.py index dcb6f65..23a4891 100644 --- a/tests/test_fri.py +++ b/tests/test_fri.py @@ -23,9 +23,9 @@ def test_fold(self): coset = Fp.primitive_element() ** (192 // len(evals)) alpha = Fp(7) - evals = FRI.fold(evals, alpha, coset, debug=True) + evals = FRI.fold(evals, alpha, coset, debug=False) coset = coset ** 2 - evals = FRI.fold(evals, alpha, coset, debug=True) + evals = FRI.fold(evals, alpha, coset, debug=False) assert evals[0] == evals[1] == evals[2] == evals[3] @@ -52,6 +52,7 @@ def test_prove(self): from sage.all import GF from field import magic from random import randint + from merlin.merlin_transcript import MerlinTranscript Fp = magic(GF(193)) @@ -60,11 +61,15 @@ def test_prove(self): rate = 4 evals_size = 4 coset = Fp.primitive_element() ** (192 // (evals_size * rate)) - point = coset ** randint(evals_size * rate, 192) * Fp.primitive_element() - evals = [randint(0, 193) for i in range(evals_size)] + point = Fp.primitive_element() + evals = [i for i in range(evals_size)] value = UniPolynomial.uni_eval_from_evals(evals, point, [coset ** i for i in range(len(evals))]) - proof = FRI.prove(evals, rate, point, coset, [coset ** i for i in range(evals_size * rate)], debug=False) - FRI.verify(proof['degree_bound'], evals_size, rate, proof, point, value, [coset ** i for i in range(evals_size * rate)], coset, debug=False) + domain = [coset ** i for i in range(evals_size * rate)] + code_tree, code = FRI.commit(evals, rate, domain, debug=False) + transcript = MerlinTranscript(b'test') + transcript.append_message(b"code", code_tree.root.encode('ascii')) + proof = FRI.prove(code, code_tree, value, point, domain, rate, evals_size, coset, transcript, debug=False) + FRI.verify(evals_size, rate, proof, point, value, domain, coset, MerlinTranscript(b'test'), debug=False) if __name__ == '__main__':