Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement np_shuffle #68

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,6 @@ ENV/

# mypy
.mypy_cache/

# vscode
.vscode
78 changes: 78 additions & 0 deletions mpyc/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,81 @@ def uniform(sectype, a, b):

s = math.copysign(1, b - a)
return a + _randbelow(sectype, round(abs(a - b) * 2**f)) * s * 2**-f


@asyncoro.mpc_coro
async def np_random_unit_vector(sectype, n):
"""Uniformly random secret rotation of [1] + [0]*(n-1).

Expected number of secret random bits needed is ceil(log_2 n) + c,
with c a small constant, c < 3.
"""

await runtime.returnType((sectype.array, True, (n,)))

if n == 1:
return runtime.np_fromlist([sectype(1)])

b = n - 1
k = b.bit_length()
x = runtime.np_random_bits(sectype, k)

i = k - 1
u = runtime.np_fromlist([x[i], 1 - x[i]])
while i:
i -= 1
if (b >> i) & 1:
v = x[i] * u
v = runtime.np_hstack((v, u - v))
u = v
elif await runtime.output(u[0] * x[i]): # TODO: mul_public
# restart, keeping unused secret random bits x[:i]
x = runtime.np_hstack((x[:i], runtime.np_random_bits(sectype, k - i)))
i = k - 1
u = runtime.np_fromlist([x[i], 1 - x[i]])
else:
v = x[i] * u[1:]
v = runtime.np_hstack((v, u[1:] - v))
u = runtime.np_hstack((u[:1], v))
return u


async def np_shuffle(a, axis=None):
"""Shuffle numpy-like array x secretly in-place, and return None.

Given array x may contain public or secret elements.
"""
sectype = type(a).sectype

if len(a.shape) > 2:
raise ValueError("Can only shuffle 1D and 2D arrays")

if axis is None:
axis = 0

if axis not in (0,1,-1):
raise ValueError("Invalid axis")

x = runtime.np_copy(a)

if axis != 0:
x = runtime.np_transpose(x)

n = x.shape[0]

for i in range(n - 1):
u = runtime.np_transpose(np_random_unit_vector(sectype, n - i))
x_u = runtime.np_matmul(u, x[i:])
if len(x.shape) > 1:
d = runtime.np_outer(u, (x[i] - x_u))
x = runtime.np_vstack((x[:i, ...], runtime.np_add(x[i:, ...], d)))
else:
d = u * (x[i] - x_u)
x = runtime.np_hstack((x[:i, ...], runtime.np_add(x[i:, ...], d)))
x = runtime.np_update(x, i, x_u)

if axis != 0:
x = runtime.np_transpose(x)

x = await runtime.gather(x)
runtime.np_update(a, range(a.shape[0]), x)
55 changes: 53 additions & 2 deletions tests/test_random.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import unittest
from mpyc.runtime import mpc
from mpyc.random import (getrandbits, randrange, random_unit_vector, randint,
shuffle, random_permutation, random_derangement,
from mpyc.random import (getrandbits, randrange, random_unit_vector, np_random_unit_vector, randint,
shuffle, np_shuffle, random_permutation, random_derangement,
choice, choices, sample, random, uniform)
from mpyc.numpy import np


class Arithmetic(unittest.TestCase):
Expand Down Expand Up @@ -77,6 +78,56 @@ def test_secint(self):
self.assertLessEqual(max(x) // 1000, 1009)
self.assertEqual(sum(x) % 1000, 0)

@unittest.skipIf(not np, 'NumPy not available or inside MPyC disabled')
def test_np_shuffle(self):
secint = mpc.SecInt()
x = secint.array(np.arange(8))
mpc.run(np_shuffle(x))
x = mpc.run(mpc.output(x))
self.assertSetEqual(set(x), set(np.arange(8)))

x = secint.array(np.array([np.arange(8)]))
mpc.run(np_shuffle(x))
x = mpc.run(mpc.output(x))
self.assertTrue((x == np.array([np.arange(8)])).all())

x_init = np.arange(8).reshape(2,4)
x = secint.array(x_init)
mpc.run(np_shuffle(x))
x = mpc.run(mpc.output(x))
self.assertIn(set(x[0,:]), [set(x_init[i,:]) for i in range(x_init.shape[0])])

x = secint.array(x_init)
mpc.run(np_shuffle(x, axis=1))
x = mpc.run(mpc.output(x))
self.assertIn(set(x[:,0]), [set(x_init[:,j]) for j in range(x_init.shape[1])])

x = secint.array(x_init)
with self.assertRaises(ValueError):
mpc.run(np_shuffle(x, 3))

x = secint.array(np.ones((8,8,8)))
with self.assertRaises(ValueError):
mpc.run(np_shuffle(x))

@unittest.skipIf(not np, 'NumPy not available or inside MPyC disabled')
def test_np_random_unit_vector(self):
secint = mpc.SecInt()
x = mpc.run(mpc.output(np_random_unit_vector(secint, 4)))
self.assertEqual(sum(x), 1)

secfxp = mpc.SecFxp()
x = mpc.run(mpc.output(np_random_unit_vector(secfxp, 3)))
self.assertEqual(int(sum(x)), 1)

secfld = mpc.SecFld(256)
x = mpc.run(mpc.output(np_random_unit_vector(secfld, 2)))
self.assertEqual(int(sum(x)), 1)

secfld = mpc.SecFld(257)
x = mpc.run(mpc.output(np_random_unit_vector(secfld, 1)))
self.assertEqual(int(sum(x)), 1)

def test_secfxp(self):
secfxp = mpc.SecFxp()
a = getrandbits(secfxp, 10)
Expand Down