Skip to content

Commit

Permalink
Merge pull request #34 from NillionNetwork/feat/simple-shuffle
Browse files Browse the repository at this point in the history
feat: add simple shuffle example
  • Loading branch information
oceans404 authored Sep 4, 2024
2 parents 7bb25b7 + d66c75e commit 47b4d0b
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 1 deletion.
7 changes: 6 additions & 1 deletion nada-project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -237,4 +237,9 @@ prime_size = 128
[[programs]]
path = "src/shuffle.py"
name = "shuffle"
prime_size = 128
prime_size = 128

[[programs]]
path = "src/shuffle_simple.py"
name = "shuffle_simple"
prime_size = 128
21 changes: 21 additions & 0 deletions src/shuffle_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from nada_dsl import SecretInteger

import nada_numpy as na
from nada_numpy import shuffle


def nada_main():

# Note:
# The current shuffle operation only supports vectors with
# a power-of-two size, e.g., 2, 4, 8, 16, 32, ...
size=4

parties = na.parties(2)
nums = na.array([size], parties[0], "num", SecretInteger)

shuffled_nums = shuffle(nums)

return (
na.output(shuffled_nums, parties[1], "shuffled_num")
)
45 changes: 45 additions & 0 deletions tests/shuffle_simple_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from nada_test import nada_test

inputs = {"num_0": 10, "num_1": 20, "num_2": 30, "num_3": 40}

# Test that the shuffled array contains the same values as the input, regardless of order
@nada_test(program="shuffle_simple")
def shuffle_simple_test_same_values():
outputs = yield inputs

shuffled_nums = [
outputs["shuffled_num_0"],
outputs["shuffled_num_1"],
outputs["shuffled_num_2"],
outputs["shuffled_num_3"]
]

# Assert that the sorted output contains the same values as the sorted input
assert sorted(shuffled_nums) == sorted([
inputs["num_0"],
inputs["num_1"],
inputs["num_2"],
inputs["num_3"]
]), "Test failed: the shuffled array contains different values."

# Test that the resulting shuffled array is not in the same order as the input
@nada_test(program="shuffle_simple")
def shuffle_simple_test_not_same_order():
outputs = yield inputs

original_nums = [
inputs["num_0"],
inputs["num_1"],
inputs["num_2"],
inputs["num_3"]
]

shuffled_nums = [
outputs["shuffled_num_0"],
outputs["shuffled_num_1"],
outputs["shuffled_num_2"],
outputs["shuffled_num_3"]
]

# Assert that the shuffled numbers are NOT in the same order as the input
assert shuffled_nums != original_nums, "Test failed: the order did not change"
12 changes: 12 additions & 0 deletions tests/shuffle_simple_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
program: shuffle_simple
inputs:
num_0: 3
num_1: 3
num_2: 3
num_3: 3
expected_outputs:
shuffled_num_0: 3
shuffled_num_1: 3
shuffled_num_2: 3
shuffled_num_3: 3

0 comments on commit 47b4d0b

Please sign in to comment.