Skip to content

Commit

Permalink
Merge pull request #7 from sec-bit/eliminate_mle1
Browse files Browse the repository at this point in the history
use mle2.py instead of mle.py
  • Loading branch information
ahy231 authored Sep 26, 2024
2 parents 8282f2a + 5570845 commit 9764d05
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 221 deletions.
124 changes: 63 additions & 61 deletions src/Basefold.ipynb

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions src/Basefold.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import field
from random import randint
from mle import eqs_over_hypercube, uni_eval_from_evals, mle_eval_from_evals
from mle2 import MLEPolynomial
from unipolynomial import UniPolynomial
from merkle import MerkleTree, verify_decommitment
from merlin.merlin_transcript import MerlinTranscript
from sage.all import *
Expand Down Expand Up @@ -145,7 +146,7 @@ def prove_basefold_evaluation_arg_multilinear_basis(f_code, f_evals, us, v, k, k
assert len(T) == k, "wrong table size, k={}, len(T)={}".format(k, len(T))
f_code_copy = f_code[:]
f = f_evals[:]
eq = eqs_over_hypercube(us)
eq = MLEPolynomial.eqs_over_hypercube(us)

challenge_vec = []
sumcheck_sum = v
Expand Down Expand Up @@ -183,7 +184,7 @@ def prove_basefold_evaluation_arg_multilinear_basis(f_code, f_evals, us, v, k, k
if debug: print("> sumcheck: eq_folded = {}".format(eq))

# compute the new sum = h(alpha)
sumcheck_sum = uni_eval_from_evals([h_eval_at_0, h_eval_at_1, h_eval_at_2], alpha, [0,1,2])
sumcheck_sum = UniPolynomial.uni_eval_from_evals([h_eval_at_0, h_eval_at_1, h_eval_at_2], alpha, [Fp(0),Fp(1),Fp(2)])
if debug: print("> sumcheck: sumcheck_sum = {}".format(sumcheck_sum))

if debug: print("fri round {}".format(i))
Expand Down Expand Up @@ -288,7 +289,7 @@ def verify_basefold_evaluation_arg_multilinear_basis(N, commit, proof, us, v, d,
f_code_vec = proof['f_code_vec']
sumcheck_sum = v
half = n >> 1
eq_evals = eqs_over_hypercube(us)
eq_evals = MLEPolynomial.eqs_over_hypercube(us)

for i in range(k):
if debug: print("sumcheck round {}".format(i))
Expand All @@ -300,7 +301,7 @@ def verify_basefold_evaluation_arg_multilinear_basis(N, commit, proof, us, v, d,

alpha = challenge_vec[i]

sumcheck_sum = uni_eval_from_evals(h_evals, alpha, [0,1,2])
sumcheck_sum = UniPolynomial.uni_eval_from_evals(h_evals, alpha, [Fp(0),Fp(1),Fp(2)])

eq_low = eq_evals[:half]
eq_high = eq_evals[half:]
Expand All @@ -325,7 +326,7 @@ def verify_basefold_evaluation_arg_multilinear_basis(N, commit, proof, us, v, d,
if debug: print("f_eval_at_random={}".format(f_eval_at_random))
if debug: print("rs_encode([f_eval_at_random], k0=1, c=blowup_factor)=", rs_encode([f_eval_at_random], k0=1, c=blowup_factor))
assert rs_encode([f_eval_at_random], k0=1, c=blowup_factor) == f_code_folded, "❌: Encode(f(rs)) != f_code_0"
if debug: print("✅: Verified! fold({}) == encode(fold(f_eq)/fold(eq(us)))".format(f_code))
if debug: print("✅: Verified! fold({}) == encode(fold(f_eq)/fold(eq(us)))".format(commit))

return True

Expand Down Expand Up @@ -370,7 +371,7 @@ def basefold_fri_multilinear_basis(vs, table, c, debug=False):
ff_code = basefold_encode(m=ff, k0=2 ** log_k0, depth=log_n - log_k0, c=blowup_factor, G0=rs_encode, T=T)
commit = MerkleTree(ff_code)
point = [randint(0, 100) for _ in range(log_n)]
eval = mle_eval_from_evals(ff, point)
eval = MLEPolynomial.evaluate_from_evals(ff, point)

transcript = MerlinTranscript(b"verify queries")
transcript.append_message(b"commit.root", bytes(commit.root, 'ascii'))
Expand Down
5 changes: 4 additions & 1 deletion src/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def __rmul__(self, other):
def __truediv__(self, other):
return self._operate(other, lambda a, b: a / b, 'div')

def __rtruediv__(self, other):
return self._operate(other, lambda a, b: b / a, 'div')

def __pow__(self, exponent):
self._increment_count('mul') # Consider power as a series of multiplications
if isinstance(exponent, int):
Expand Down Expand Up @@ -107,7 +110,7 @@ def zero():

@classmethod
def random_element(cls):
return cls(random.randint(0, 139))
return cls(random.randint(0, 193))

def magic(Fp):
def magic_field(value):
Expand Down
138 changes: 0 additions & 138 deletions src/mle.py

This file was deleted.

34 changes: 30 additions & 4 deletions src/mle2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env sage -python

from utils import log_2, pow_2
from functools import reduce
from utils import log_2, pow_2, bits_le_with_width
# from sage.all import product, GF

class MLEPolynomial:
Expand All @@ -19,10 +20,22 @@ def eqs_over_hypercube(cls, rs):
half = 1
for i in range(k):
for j in range(half):
evals[j+half] = Fp(evals[j] * rs[i])
evals[j] = Fp(evals[j] - evals[j+half])
evals[j+half] = evals[j] * rs[i]
evals[j] = evals[j] - evals[j+half]
half *= 2
return evals

@classmethod
def eqs_over_hypercube_slow(cls, k, indeterminates):
if k > 5:
raise ValueError("k>5 isn't supported")
xs = indeterminates[:k]
n = 1 << k
eqs = [1] * n
for i in range(n):
bs = bits_le_with_width(i, k)
eqs[i] = reduce(lambda v, j: v * ((1 - xs[j]) * (1 - bs[j]) + xs[j] * bs[j]), range(k), 1)
return eqs

@classmethod
def from_coeffs(cls, coeffs, num_var):
Expand Down Expand Up @@ -58,7 +71,6 @@ def compute_coeffs_from_evals(cls, f_evals):

@classmethod
def evaluate_from_evals(cls, evals, zs):
z = len(zs)
f = evals

half = len(f) >> 1
Expand All @@ -69,6 +81,20 @@ def evaluate_from_evals(cls, evals, zs):
half >>= 1
return f[0]

@classmethod
def evaluate_from_evals_2(cls, evals, zs):
k = len(zs)
f = evals

half = len(f) >> 1
for i in range(k):
u = zs[k-i-1]

f = [(1-u) * f[j] + u * f[j+half] for j in range(half)]
half >>= 1

return f[0]

def evaluate(self, zs: list):
"""
Evaluate the MLE polynomial at the given points.
Expand Down
23 changes: 18 additions & 5 deletions src/mle2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,45 @@ def test_evaluate(self):
evals = [Fp(1), Fp(2), Fp(3), Fp(4)]
mle = MLEPolynomial(evals, 2)
result = mle.evaluate([Fp(1), Fp(2)])
self.assertIsInstance(result, Fp)
# self.assertIsInstance(result, Fp)

def test_evaluate_from_coeffs(self):
coeffs = [Fp(1), Fp(2), Fp(3), Fp(4)]
result = MLEPolynomial.evaluate_from_coeffs(coeffs, [Fp(1), Fp(2)])
self.assertIsInstance(result, Fp)
# self.assertIsInstance(result, Fp)

def test_eval_from_coeffs(self):
coeffs = [Fp(1), Fp(2), Fp(3), Fp(4)]
result = MLEPolynomial.eval_from_coeffs(coeffs, Fp(2))
self.assertIsInstance(result, Fp)
# self.assertIsInstance(result, Fp)

def test_decompose_by_div(self):
evals = [Fp(i) for i in range(8)]
mle = MLEPolynomial(evals, 3)
point = [Fp(1), Fp(2), Fp(3)]
quotients, remainder = mle.decompose_by_div(point)
self.assertEqual(len(quotients), 3)
self.assertIsInstance(remainder, Fp)
# self.assertIsInstance(remainder, Fp)

def test_decompose_by_div_from_coeffs(self):
coeffs = [Fp(i) for i in range(8)]
point = [Fp(1), Fp(2), Fp(3)]
quotients, remainder = MLEPolynomial.decompose_by_div_from_coeffs(coeffs, point)
self.assertEqual(len(quotients), 3)
self.assertIsInstance(remainder, Fp)
# self.assertIsInstance(remainder, Fp)

def test_eq_poly_vec(self):
point = [Fp(i) for i in range(4)]
res1 = MLEPolynomial.eqs_over_hypercube(point)
res2 = MLEPolynomial.eqs_over_hypercube_slow(4, point)
self.assertEqual(res1, res2)

def evaluate_from_evals(self):
point = [Fp(i) for i in range(4)]
evals = [Fp(i) for i in range(16)]
res1 = MLEPolynomial.evaluate_from_evals(evals, point)
res2 = MLEPolynomial.evaluate_from_evals_2(evals, point)
self.assertEqual(res1, res2)

if __name__ == '__main__':
unittest.main()
28 changes: 28 additions & 0 deletions src/unipolynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,34 @@ def interpolate(cls, evals, domain):
coeffs = cls.compute_coeffs_from_evals_fast(evals, domain)
return cls(coeffs)

# barycentric interpolation
@classmethod
def barycentric_weights(cls, D):
n = len(D)
weights = [1] * n
for i in range(n):
# weights[i] = product([(D[i] - D[j]) if i !=j else Fp(1) for j in range(n)])
for j in range(n):
if i==j:
weights[i] *= 1
continue
weights[i] *= (D[i] - D[j])
weights[i] = 1/weights[i]
return weights

@classmethod
def uni_eval_from_evals(cls, evals, z, D):
n = len(evals)
if n != len(D):
raise ValueError("Domain size should be equal to the length of evaluations")
if z in D:
return evals[D.index(z)]
weights = cls.barycentric_weights(D)
# print("weights={}".format(weights))
e_vec = [weights[i] / (z - D[i]) for i in range(n)]
numerator = sum([e_vec[i] * evals[i] for i in range(n)])
denominator = sum([e_vec[i] for i in range(n)])
return (numerator / denominator)

# Example usage
if __name__ == "__main__":
Expand Down
6 changes: 2 additions & 4 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def next_power_of_two(n):
if is_power_of_two(n):
return n
d = n
k = 0
k = 1
while d > 0:
d >>= 1
k += 1
k <<= 1
return k

def log_2(x):
Expand All @@ -56,5 +56,3 @@ def log_2(x):
x >>= 1 # Bit shift right (equivalent to integer division by 2)
result += 1
return result

is_power_of_two(15), next_power_of_two(15)
Loading

0 comments on commit 9764d05

Please sign in to comment.