Skip to content

Commit

Permalink
fix(mmcs): fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
ahy231 committed Nov 27, 2024
1 parent 2114b20 commit 3d042a5
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
38 changes: 27 additions & 11 deletions src/mmcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ def commit(cls, vecs, debug=False):
layers.append([cls.compress((layers[i-1][j], layers[i-1][j+1]), debug) for j in range(0, len(layers[i-1]), 2)])
layers[-1] = [cls.compress((layers[-1][j], cls.hash(vecs[i][j])), debug) for j in range(len(layers[-1]))]

for i in range(log_2(min_height)):
for i in range(len(vecs), len(vecs) + log_2(min_height)):
layers.append([cls.compress((layers[i-1][j], layers[i-1][j+1]), debug) for j in range(0, len(layers[i-1]), 2)])
layers[-1] = [cls.compress((layers[-1][j], cls.default_digest), debug) for j in range(len(layers[-1]))]
print("layers[-1]:", layers[-1])

return {
'layers': layers,
Expand All @@ -48,28 +49,43 @@ def commit(cls, vecs, debug=False):
@classmethod
def open(cls, index, prover_data, debug=False):
assert cls.configured, "MMCS is not configured"

layers = prover_data['layers']
if debug: print(layers)
if debug:
print("layers:")
for layer in layers:
print(layer)
vecs = prover_data['vecs']
if debug: print(vecs)
if debug:
print("vecs:")
for vec in vecs:
print(vec)

height = len(layers)
openings = [vecs[i][index >> (32 - len(vecs[i]))] for i in range(height)]
proof = [layers[i][(index >> (32 - len(layers[i]))) ^ 1] for i in range(height - 1)]
openings = [vecs[i][index >> (32 - log_2(len(vecs[i])))] for i in range(len(vecs))]
proof = [layers[i][(index >> (32 - log_2(len(layers[i])))) ^ 1] for i in range(len(layers) - 1)]
root = layers[-1][0]
if debug: print(openings, proof, root)
if debug: print("openings:", openings)
if debug: print("proof:", proof)
if debug: print("root:", root)
return openings, proof, root

@classmethod
def verify(cls, index, openings, proof, root, debug=False):
index >>= (32 - (1 << (len(openings) - 1)))
index >>= (32 - len(proof))
assert index < 1 << len(proof), f"index {index} is out of bounds"
if debug: print("index:", index)
expected = cls.hash(openings[0])
index >>= 1
for i in range(1, len(openings)):
if debug: print("expected:", expected)
for i in range(1, 1 + len(proof)):
if debug: print("index:", index)
if index & 1:
expected = cls.compress((proof[i-1], expected), debug)
else:
expected = cls.compress((expected, proof[i-1]), debug)
expected = cls.compress((expected, cls.hash(openings[i])), debug)
if i > len(openings) - 1:
expected = cls.compress((expected, cls.default_digest), debug)
else:
expected = cls.compress((expected, cls.hash(openings[i])), debug)
index >>= 1
if debug: print("expected:", expected)
assert expected == root, f"expected {expected}, root {root}"
9 changes: 5 additions & 4 deletions tests/test_mmcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
class TestMMCS(TestCase):
def setUp(self):
def hash(x): return x
def compress(x): return x[0] + x[1]
def compress(x): return x[0] - x[1]
MMCS.configure(hash, compress)

def test_mmcs(self):
evals = [[randint(0, 2**32-1) for _ in range(1 << i)] for i in range(4)]
evals = [[randint(0, 2**32-1) for _ in range(4 * (1 << i))] for i in range(4)]
evals = list(reversed(evals))
prover_data = MMCS.commit(evals, debug=False)
openings, proof, root = MMCS.open(0, prover_data, debug=False)
MMCS.verify(0, openings, proof, root, debug=False)
idx = randint(0, 2 ** 32 - 1)
openings, proof, root = MMCS.open(idx, prover_data, debug=True)
MMCS.verify(idx, openings, proof, root, debug=True)

if __name__ == "__main__":
main()

0 comments on commit 3d042a5

Please sign in to comment.