diff --git a/src/mmcs.py b/src/mmcs.py index cf6868b..3fa56ac 100644 --- a/src/mmcs.py +++ b/src/mmcs.py @@ -36,9 +36,10 @@ def commit(cls, vecs, debug=False): layers.append([cls.compress((layers[i-1][j], layers[i-1][j+1]), debug) for j in range(0, len(layers[i-1]), 2)]) layers[-1] = [cls.compress((layers[-1][j], cls.hash(vecs[i][j])), debug) for j in range(len(layers[-1]))] - for i in range(log_2(min_height)): + for i in range(len(vecs), len(vecs) + log_2(min_height)): layers.append([cls.compress((layers[i-1][j], layers[i-1][j+1]), debug) for j in range(0, len(layers[i-1]), 2)]) layers[-1] = [cls.compress((layers[-1][j], cls.default_digest), debug) for j in range(len(layers[-1]))] + print("layers[-1]:", layers[-1]) return { 'layers': layers, @@ -48,28 +49,43 @@ def commit(cls, vecs, debug=False): @classmethod def open(cls, index, prover_data, debug=False): assert cls.configured, "MMCS is not configured" + layers = prover_data['layers'] - if debug: print(layers) + if debug: + print("layers:") + for layer in layers: + print(layer) vecs = prover_data['vecs'] - if debug: print(vecs) + if debug: + print("vecs:") + for vec in vecs: + print(vec) - height = len(layers) - openings = [vecs[i][index >> (32 - len(vecs[i]))] for i in range(height)] - proof = [layers[i][(index >> (32 - len(layers[i]))) ^ 1] for i in range(height - 1)] + openings = [vecs[i][index >> (32 - log_2(len(vecs[i])))] for i in range(len(vecs))] + proof = [layers[i][(index >> (32 - log_2(len(layers[i])))) ^ 1] for i in range(len(layers) - 1)] root = layers[-1][0] - if debug: print(openings, proof, root) + if debug: print("openings:", openings) + if debug: print("proof:", proof) + if debug: print("root:", root) return openings, proof, root @classmethod def verify(cls, index, openings, proof, root, debug=False): - index >>= (32 - (1 << (len(openings) - 1))) + index >>= (32 - len(proof)) + assert index < 1 << len(proof), f"index {index} is out of bounds" + if debug: print("index:", index) expected = cls.hash(openings[0]) - index >>= 1 - for i in range(1, len(openings)): + if debug: print("expected:", expected) + for i in range(1, 1 + len(proof)): + if debug: print("index:", index) if index & 1: expected = cls.compress((proof[i-1], expected), debug) else: expected = cls.compress((expected, proof[i-1]), debug) - expected = cls.compress((expected, cls.hash(openings[i])), debug) + if i > len(openings) - 1: + expected = cls.compress((expected, cls.default_digest), debug) + else: + expected = cls.compress((expected, cls.hash(openings[i])), debug) index >>= 1 + if debug: print("expected:", expected) assert expected == root, f"expected {expected}, root {root}" diff --git a/tests/test_mmcs.py b/tests/test_mmcs.py index 35e4dc7..e989c6c 100644 --- a/tests/test_mmcs.py +++ b/tests/test_mmcs.py @@ -10,15 +10,16 @@ class TestMMCS(TestCase): def setUp(self): def hash(x): return x - def compress(x): return x[0] + x[1] + def compress(x): return x[0] - x[1] MMCS.configure(hash, compress) def test_mmcs(self): - evals = [[randint(0, 2**32-1) for _ in range(1 << i)] for i in range(4)] + evals = [[randint(0, 2**32-1) for _ in range(4 * (1 << i))] for i in range(4)] evals = list(reversed(evals)) prover_data = MMCS.commit(evals, debug=False) - openings, proof, root = MMCS.open(0, prover_data, debug=False) - MMCS.verify(0, openings, proof, root, debug=False) + idx = randint(0, 2 ** 32 - 1) + openings, proof, root = MMCS.open(idx, prover_data, debug=True) + MMCS.verify(idx, openings, proof, root, debug=True) if __name__ == "__main__": main()