diff --git a/.gitignore b/.gitignore index 7bbc71c0..b660602d 100644 --- a/.gitignore +++ b/.gitignore @@ -99,3 +99,6 @@ ENV/ # mypy .mypy_cache/ + +# vscode +.vscode diff --git a/mpyc/random.py b/mpyc/random.py index 1cfb5d81..7592da06 100644 --- a/mpyc/random.py +++ b/mpyc/random.py @@ -322,3 +322,81 @@ def uniform(sectype, a, b): s = math.copysign(1, b - a) return a + _randbelow(sectype, round(abs(a - b) * 2**f)) * s * 2**-f + + +@asyncoro.mpc_coro +async def np_random_unit_vector(sectype, n): + """Uniformly random secret rotation of [1] + [0]*(n-1). + + Expected number of secret random bits needed is ceil(log_2 n) + c, + with c a small constant, c < 3. + """ + + await runtime.returnType((sectype.array, True, (n,))) + + if n == 1: + return runtime.np_fromlist([sectype(1)]) + + b = n - 1 + k = b.bit_length() + x = runtime.np_random_bits(sectype, k) + + i = k - 1 + u = runtime.np_fromlist([x[i], 1 - x[i]]) + while i: + i -= 1 + if (b >> i) & 1: + v = x[i] * u + v = runtime.np_hstack((v, u - v)) + u = v + elif await runtime.output(u[0] * x[i]): # TODO: mul_public + # restart, keeping unused secret random bits x[:i] + x = runtime.np_hstack((x[:i], runtime.np_random_bits(sectype, k - i))) + i = k - 1 + u = runtime.np_fromlist([x[i], 1 - x[i]]) + else: + v = x[i] * u[1:] + v = runtime.np_hstack((v, u[1:] - v)) + u = runtime.np_hstack((u[:1], v)) + return u + + +async def np_shuffle(a, axis=None): + """Shuffle numpy-like array x secretly in-place, and return None. + + Given array x may contain public or secret elements. + """ + sectype = type(a).sectype + + if len(a.shape) > 2: + raise ValueError("Can only shuffle 1D and 2D arrays") + + if axis is None: + axis = 0 + + if axis not in (0,1,-1): + raise ValueError("Invalid axis") + + x = runtime.np_copy(a) + + if axis != 0: + x = runtime.np_transpose(x) + + n = x.shape[0] + + for i in range(n - 1): + u = runtime.np_transpose(np_random_unit_vector(sectype, n - i)) + x_u = runtime.np_matmul(u, x[i:]) + if len(x.shape) > 1: + d = runtime.np_outer(u, (x[i] - x_u)) + x = runtime.np_vstack((x[:i, ...], runtime.np_add(x[i:, ...], d))) + else: + d = u * (x[i] - x_u) + x = runtime.np_hstack((x[:i, ...], runtime.np_add(x[i:, ...], d))) + x = runtime.np_update(x, i, x_u) + + if axis != 0: + x = runtime.np_transpose(x) + + x = await runtime.gather(x) + runtime.np_update(a, range(a.shape[0]), x) diff --git a/tests/test_random.py b/tests/test_random.py index 489ca276..3a50abf2 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -1,8 +1,9 @@ import unittest from mpyc.runtime import mpc -from mpyc.random import (getrandbits, randrange, random_unit_vector, randint, - shuffle, random_permutation, random_derangement, +from mpyc.random import (getrandbits, randrange, random_unit_vector, np_random_unit_vector, randint, + shuffle, np_shuffle, random_permutation, random_derangement, choice, choices, sample, random, uniform) +from mpyc.numpy import np class Arithmetic(unittest.TestCase): @@ -77,6 +78,56 @@ def test_secint(self): self.assertLessEqual(max(x) // 1000, 1009) self.assertEqual(sum(x) % 1000, 0) + @unittest.skipIf(not np, 'NumPy not available or inside MPyC disabled') + def test_np_shuffle(self): + secint = mpc.SecInt() + x = secint.array(np.arange(8)) + mpc.run(np_shuffle(x)) + x = mpc.run(mpc.output(x)) + self.assertSetEqual(set(x), set(np.arange(8))) + + x = secint.array(np.array([np.arange(8)])) + mpc.run(np_shuffle(x)) + x = mpc.run(mpc.output(x)) + self.assertTrue((x == np.array([np.arange(8)])).all()) + + x_init = np.arange(8).reshape(2,4) + x = secint.array(x_init) + mpc.run(np_shuffle(x)) + x = mpc.run(mpc.output(x)) + self.assertIn(set(x[0,:]), [set(x_init[i,:]) for i in range(x_init.shape[0])]) + + x = secint.array(x_init) + mpc.run(np_shuffle(x, axis=1)) + x = mpc.run(mpc.output(x)) + self.assertIn(set(x[:,0]), [set(x_init[:,j]) for j in range(x_init.shape[1])]) + + x = secint.array(x_init) + with self.assertRaises(ValueError): + mpc.run(np_shuffle(x, 3)) + + x = secint.array(np.ones((8,8,8))) + with self.assertRaises(ValueError): + mpc.run(np_shuffle(x)) + + @unittest.skipIf(not np, 'NumPy not available or inside MPyC disabled') + def test_np_random_unit_vector(self): + secint = mpc.SecInt() + x = mpc.run(mpc.output(np_random_unit_vector(secint, 4))) + self.assertEqual(sum(x), 1) + + secfxp = mpc.SecFxp() + x = mpc.run(mpc.output(np_random_unit_vector(secfxp, 3))) + self.assertEqual(int(sum(x)), 1) + + secfld = mpc.SecFld(256) + x = mpc.run(mpc.output(np_random_unit_vector(secfld, 2))) + self.assertEqual(int(sum(x)), 1) + + secfld = mpc.SecFld(257) + x = mpc.run(mpc.output(np_random_unit_vector(secfld, 1))) + self.assertEqual(int(sum(x)), 1) + def test_secfxp(self): secfxp = mpc.SecFxp() a = getrandbits(secfxp, 10)