diff --git a/src/fri.py b/src/fri.py index 23b47fe..7ca7666 100644 --- a/src/fri.py +++ b/src/fri.py @@ -145,6 +145,7 @@ def prove_low_degree(evals, rate, degree_bound, gen, num_verifier_queries, trans 'first_oracle': first_tree.root, 'intermediate_oracles': [tree.root for tree in trees], 'degree_bound': degree_bound, + 'final_value': evals[0], } # f(x) = f0(x^2) + x * f1(x^2) @@ -278,7 +279,9 @@ def verify_queries(proof, k, num_vars, num_verifier_queries, T, transcript, debu if debug: print("code_left:", code_left) if debug: print("code_right:", code_right) if debug: print("alpha:", alpha) - assert f_code_folded == ((code_left + code_right)/2 + alpha * (code_left - code_right)/(2*table[x0])), f"failed to check fri, i: {i}, x0: {x0}, x1: {x1}, code_left: {code_left}, code_right: {code_right}, alpha: {alpha}, generator: {table}" + assert f_code_folded == (code_left + code_right)/2 + alpha * (code_left - code_right)/(2*table[x0]), f"failed to check fri, i: {i}, x0: {x0}, x1: {x1}, code_left: {code_left}, code_right: {code_right}, alpha: {alpha}, generator: {table}" + else: + assert proof["final_value"] == (code_left + code_right)/2 + alpha * (code_left - code_right)/(2*table[x0]), f"failed to check fri, i: {i}, x0: {x0}, x1: {x1}, code_left: {code_left}, code_right: {code_right}, alpha: {alpha}, generator: {table}, final_value: {proof['final_value']}" if i == 0: assert verify_decommitment(x0, code_left, mp, proof['first_oracle']), "failed to check decommitment at first level" diff --git a/tests/test_fri.py b/tests/test_fri.py index cf9af06..dcb6f65 100644 --- a/tests/test_fri.py +++ b/tests/test_fri.py @@ -60,7 +60,7 @@ def test_prove(self): rate = 4 evals_size = 4 coset = Fp.primitive_element() ** (192 // (evals_size * rate)) - point = coset ** 0 * Fp.primitive_element() + point = coset ** randint(evals_size * rate, 192) * Fp.primitive_element() evals = [randint(0, 193) for i in range(evals_size)] value = UniPolynomial.uni_eval_from_evals(evals, point, [coset ** i for i in range(len(evals))]) proof = FRI.prove(evals, rate, point, coset, [coset ** i for i in range(evals_size * rate)], debug=False)