Skip to content

Commit

Permalink
feat(fri): implement FRI protocol with open and verify methods
Browse files Browse the repository at this point in the history
Implement FRI (Fast Reed-Solomon Interactive Oracle Proof) protocol
with open and verify methods. Add helper functions for RS encoding,
folding, and query verification. Update tests to cover new functionality.
  • Loading branch information
ahy231 committed Oct 10, 2024
1 parent fb97cde commit 7213571
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 45 deletions.
184 changes: 139 additions & 45 deletions src/fri.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,105 @@
from merkle import MerkleTree, verify_decommitment
from merlin.merlin_transcript import MerlinTranscript
from utils import from_bytes, log_2, is_power_of_two
from utils import from_bytes, log_2, is_power_of_two, next_power_of_two
from unipolynomial import UniPolynomial

class FRI:
security_level = 128

@classmethod
def open(cls, evals, rate, point, gen, domain, debug=False):
if debug: print("evals:", evals)
N = len(evals)
degree_bound = next_power_of_two(N - 1) if not is_power_of_two(N - 1) else N - 1
if debug: print("degree_bound:", degree_bound)
coeffs = UniPolynomial.compute_coeffs_from_evals_fast(evals, domain[:N])
if debug: print("coeffs:", coeffs)
val = UniPolynomial.uni_eval_from_evals(evals, point, domain[:N])
if debug: print("val:", val)
assert len(domain) == N * rate, f"domain: {domain}, N: {N}, rate: {rate}"
code = cls.rs_encode_single(coeffs, domain, rate)
if debug: print("code:", code)
assert len(code) == N * rate, f"code: {code}, degree_bound: {degree_bound}, rate: {rate}"

code_tree = MerkleTree(code)
transcript = MerlinTranscript(b"FRI.open")
transcript.append_message(b"code", code_tree.root.encode('ascii'))

quotient = [(code[i] - val) / (domain[i] - point) for i in range(len(code))]

quotient_tree = MerkleTree(quotient)
transcript.append_message(b"quotient", quotient_tree.root.encode('ascii'))

z = from_bytes(transcript.challenge_bytes(b"z", 4)) % len(code)
code_at_z_proof = code_tree.get_authentication_path(z)
quotient_at_z_proof = quotient_tree.get_authentication_path(z)

if debug:
print('z:', z)
print('code[z]:', code[z])
print('quotient[z]:', quotient[z])
print('domain[z]:', domain[z])
print('point:', point)
print('value:', val)
assert code[z] - val == quotient[z] * (domain[z] - point), \
"failed to generate quotient, code: {}, quotient: {}, val: {}, z: {}, point: {}"\
.format(code, quotient, val, z, point)

num_verifier_queries = cls.security_level // log_2(rate)
if cls.security_level % log_2(rate) != 0:
num_verifier_queries += 1

low_degree_proof = cls.prove_low_degree(quotient, rate, degree_bound, gen, num_verifier_queries, debug)

return {
'low_degree_proof': low_degree_proof,
'code_commitment': code_tree.root,
'quotient_commitment': quotient_tree.root,
'code_at_z_proof': code_at_z_proof,
'quotient_at_z_proof': quotient_at_z_proof,
'code_at_z': code[z],
'quotient_at_z': quotient[z],
'degree_bound': degree_bound,
}

@classmethod
def verify(cls, degree_bound, evals_size, rate, proof, point, value, domain, gen, debug=False):

assert degree_bound >= proof['degree_bound']
degree_bound = proof['degree_bound']

transcript = MerlinTranscript(b"FRI.open")

code_commitment = proof['code_commitment']
quotient_commitment = proof['quotient_commitment']

transcript.append_message(b"code", code_commitment.encode('ascii'))
transcript.append_message(b"quotient", quotient_commitment.encode('ascii'))

z = from_bytes(transcript.challenge_bytes(b"z", 4)) % (evals_size * rate)
code_at_z_proof = proof['code_at_z_proof']
quotient_at_z_proof = proof['quotient_at_z_proof']

if debug:
print('z: ', z)
print('code_at_z: ', proof['code_at_z'])
print('quotient_at_z: ', proof['quotient_at_z'])
print('point:', point)
print('value:', value)

assert verify_decommitment(z, proof['code_at_z'], code_at_z_proof, code_commitment), f"failed to check decommitment at code_at_z, z: {z}, code_at_z: {proof['code_at_z']}, code_commitment: {code_commitment}"
assert verify_decommitment(z, proof['quotient_at_z'], quotient_at_z_proof, quotient_commitment), f"failed to check decommitment at quotient_at_z, z: {z}, quotient_at_z: {proof['quotient_at_z']}, quotient_commitment: {quotient_commitment}"

assert proof['code_at_z'] - value == proof['quotient_at_z'] * (domain[z] - point)

num_verifier_queries = cls.security_level // log_2(rate)
if cls.security_level % log_2(rate) != 0:
num_verifier_queries += 1

cls.verify_low_degree(degree_bound, rate, proof['low_degree_proof'], gen, num_verifier_queries, debug)

@staticmethod
def prove_low_degree(evals, degree_bound, coset, num_verifier_queries, debug=False):
def prove_low_degree(evals, rate, degree_bound, gen, num_verifier_queries, debug=False):
assert is_power_of_two(degree_bound)

first_tree = MerkleTree(evals)
Expand All @@ -20,8 +115,8 @@ def prove_low_degree(evals, degree_bound, coset, num_verifier_queries, debug=Fal
for _ in range(log_2(degree_bound)):
if debug: print("evals:", evals)
if debug: print("alpha:", alpha)
if debug: print("coset:", coset)
evals = FRI.fold(evals, alpha, coset)
if debug: print("generator:", gen)
evals = FRI.fold(evals, alpha, gen)
tree = MerkleTree(evals)
trees.append(tree)
tree_evals.append(evals)
Expand All @@ -30,16 +125,24 @@ def prove_low_degree(evals, degree_bound, coset, num_verifier_queries, debug=Fal
alpha = transcript.challenge_bytes(b"alpha", 4)
alpha = from_bytes(alpha)

coset *= coset
gen *= gen

if debug:
assert len(evals) == rate, f"evals: {evals}, rate: {rate}"
for i in range(len(evals)):
if i != 0:
assert evals[i] == evals[0], f"evals: {evals}"

# query phase
query_paths, merkle_paths = FRI.query_phase(transcript, first_tree, evals_copy, trees, tree_evals, len(evals_copy), num_verifier_queries, debug)
assert len(evals_copy) == degree_bound * rate, f"evals_copy: {evals_copy}, degree_bound: {degree_bound}, rate: {rate}"
query_paths, merkle_paths = FRI.query_phase(transcript, first_tree, evals_copy, trees, tree_evals, degree_bound * rate, num_verifier_queries, debug)

return {
'query_paths': query_paths,
'merkle_paths': merkle_paths,
'first_oracle': first_tree.root,
'intermediate_oracles': [tree.root for tree in trees]
'intermediate_oracles': [tree.root for tree in trees],
'degree_bound': degree_bound,
}

# f(x) = f0(x^2) + x * f1(x^2)
Expand All @@ -50,21 +153,30 @@ def prove_low_degree(evals, degree_bound, coset, num_verifier_queries, debug=Fal
# f0(x^2) = (f(x) + f(-x)) / 2
# f1(x^2) = (f(x) - f(-x)) / 2x
@staticmethod
def fold(evals, alpha, coset):
def fold(evals, alpha, g, debug=False):
assert len(evals) % 2 == 0

half = len(evals) // 2
f0_evals = [(evals[i] + evals[half + i]) // 2 for i in range(half)]
f1_evals = [(evals[i] - evals[half + i]) // (2 * coset) for i in range(half)]
f0_evals = [(evals[i] + evals[half + i]) / 2 for i in range(half)]
f1_evals = [(evals[i] - evals[half + i]) / (2 * g ** i) for i in range(half)]

if debug:
x = g ** 5
f_x = UniPolynomial.uni_eval_from_evals(evals, x, [g ** i for i in range(len(evals))])
f0_x = UniPolynomial.uni_eval_from_evals(f0_evals, x ** 2, [(g ** 2) ** i for i in range(len(f0_evals))])
f1_x = UniPolynomial.uni_eval_from_evals(f1_evals, x ** 2, [(g ** 2) ** i for i in range(len(f1_evals))])
assert f_x == f0_x + x * f1_x, f"failed to fold, f_x: {f_x}, f0_x: {f0_x}, f1_x: {f1_x}, alpha: {alpha}"

return [x + alpha * y for x, y in zip(f0_evals, f1_evals)]


@staticmethod
def verify_low_degree(degree_bound, proof, coset, num_verifier_queries, debug=False):
def verify_low_degree(degree_bound, rate, proof, gen, num_verifier_queries, debug=False):
log_degree_bound = log_2(degree_bound)
log_evals = log_2(len(evals))
T = [coset**(2 ** j) for j in range(0, log_evals)]
FRI.verify_queries(proof, log_degree_bound, len(evals), num_verifier_queries, T, debug)
log_evals = log_2(degree_bound * rate)
T = [[(gen**(2 ** j)) ** i for i in range(2 ** (log_evals - j - 1))] for j in range(0, log_evals)]
if debug: print("T:", T)
FRI.verify_queries(proof, log_degree_bound, degree_bound * rate, num_verifier_queries, T, debug)

@staticmethod
def query_phase(transcript: MerlinTranscript, first_tree: MerkleTree, first_oracle, trees: list, oracles: list, num_vars, num_verifier_queries, debug=False):
Expand Down Expand Up @@ -155,16 +267,17 @@ def verify_queries(proof, k, num_vars, num_verifier_queries, T, debug=False):
if debug: print("x1:", x1)

if i != len(mps) - 1:
coset = T[i]
if debug: print("coset:", coset)
table = T[i]
if debug: print("table:", table)
f_code_folded = cur_path[i + 1][0 if x0 < num_vars_copy / 4 else 1]
alpha = fold_challenges[i]
if debug: assert x0 < len(table), f"x0: {x0}, table: {table}"
if debug: print("f_code_folded:", f_code_folded)
if debug: print("expected:", ((code_left + code_right)/2 + alpha * (code_left - code_right)/(2*coset)))
if debug: print("expected:", ((code_left + code_right)/2 + alpha * (code_left - code_right)/(2*table[x0])))
if debug: print("code_left:", code_left)
if debug: print("code_right:", code_right)
if debug: print("alpha:", alpha)
assert f_code_folded == ((code_left + code_right)/2 + alpha * (code_left - code_right)/(2*coset)), f"failed to check fri, i: {i}, x0: {x0}, x1: {x1}, code_left: {code_left}, code_right: {code_right}, alpha: {alpha}, coset: {coset}"
assert f_code_folded == ((code_left + code_right)/2 + alpha * (code_left - code_right)/(2*table[x0])), f"failed to check fri, i: {i}, x0: {x0}, x1: {x1}, code_left: {code_left}, code_right: {code_right}, alpha: {alpha}, generator: {table}"

if i == 0:
assert verify_decommitment(x0, code_left, mp, proof['first_oracle']), "failed to check decommitment at first level"
Expand All @@ -174,31 +287,12 @@ def verify_queries(proof, k, num_vars, num_verifier_queries, T, debug=False):
num_vars_copy >>= 1
q = x0

@staticmethod
def rs_encode_single(m, alpha, c):
k0 = len(m)
code = [None] * (k0 * c)
for i in range(k0 * c):
# Compute f_m(alpha[i])
code[i] = sum(m[j] * (alpha[i] ** j) for j in range(k0))
return code

def rs_encode_single(m, alpha, c):
k0 = len(m)
code = [None] * (k0 * c)
for i in range(k0 * c):
# Compute f_m(alpha[i])
code[i] = sum(m[j] * (alpha[i] ** j) for j in range(k0))
return code


if __name__ == "__main__":
from sage.all import *
from field import magic
from random import randint

Fp = magic(GF(193))

assert Fp.primitive_element() ** 192 == 1

degree_bound = 8
blow_up_factor = 4
num_verifier_queries = 8

assert is_power_of_two(degree_bound)

evals = rs_encode_single([randint(0, 193) for _ in range(degree_bound)], [Fp.primitive_element() ** (i * 192 // (degree_bound * 2 ** blow_up_factor)) for i in range(degree_bound * 2 ** blow_up_factor)], 2 ** blow_up_factor)
proof = FRI.prove_low_degree(evals, degree_bound, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False)
FRI.verify_low_degree(degree_bound, proof, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False)
71 changes: 71 additions & 0 deletions tests/test_fri.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from unittest import TestCase, main
import sys

sys.path.append('../src')
sys.path.append('src')

from fri import FRI
from utils import is_power_of_two
from unipolynomial import UniPolynomial
class TestFRI(TestCase):
def setUp(self):
# Set up a scalar field for testing (e.g., integers modulo a prime)
prime = 193 # A small prime for testing
UniPolynomial.set_scalar(int, lambda x: x % prime)

def test_fold(self):
from sage.all import GF
from field import magic

Fp = magic(GF(193))

evals = FRI.rs_encode_single([2, 3, 4, 5], [Fp.primitive_element() ** (i * 192 // 16) for i in range(16)], 4)
coset = Fp.primitive_element() ** (192 // len(evals))
alpha = Fp(7)

evals = FRI.fold(evals, alpha, coset, debug=True)
coset = coset ** 2
evals = FRI.fold(evals, alpha, coset, debug=True)

assert evals[0] == evals[1] == evals[2] == evals[3]

def test_low_degree(self):
from sage.all import GF
from field import magic
from random import randint

Fp = magic(GF(193))

assert Fp.primitive_element() ** 192 == 1

degree_bound = 8
blow_up_factor = 4
num_verifier_queries = 8

assert is_power_of_two(degree_bound)

evals = FRI.rs_encode_single([randint(0, 193) for _ in range(degree_bound)], [Fp.primitive_element() ** (i * 192 // (degree_bound * 2 ** blow_up_factor)) for i in range(degree_bound * 2 ** blow_up_factor)], 2 ** blow_up_factor)
proof = FRI.prove_low_degree(evals, 2 ** blow_up_factor, degree_bound, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False)
FRI.verify_low_degree(degree_bound, 2 ** blow_up_factor, proof, Fp.primitive_element() ** (192 // len(evals)), num_verifier_queries, debug=False)

def test_open(self):
from sage.all import GF
from field import magic
from random import randint

Fp = magic(GF(193))

assert Fp.primitive_element() ** 192 == 1

rate = 4
evals_size = 4
coset = Fp.primitive_element() ** (192 // (evals_size * rate))
point = coset ** 0 * Fp.primitive_element()
evals = [randint(0, 193) for i in range(evals_size)]
value = UniPolynomial.uni_eval_from_evals(evals, point, [coset ** i for i in range(len(evals))])
proof = FRI.open(evals, rate, point, coset, [coset ** i for i in range(evals_size * rate)], debug=False)
FRI.verify(proof['degree_bound'], evals_size, rate, proof, point, value, [coset ** i for i in range(evals_size * rate)], coset, debug=False)


if __name__ == '__main__':
main()

0 comments on commit 7213571

Please sign in to comment.