-
Notifications
You must be signed in to change notification settings - Fork 16
/
interactive_fheuint8.rs
180 lines (143 loc) · 6.67 KB
/
interactive_fheuint8.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
use itertools::Itertools;
use phantom_zone::*;
use rand::{thread_rng, Rng, RngCore};
/// Plaintext reference for `function1_fhe`: computes `((a + b) * c) * d`.
///
/// Every step wraps modulo 2^8 instead of using the plain `+` / `*`
/// operators, which panic on `u8` overflow in debug builds. `main` feeds this
/// function uniformly random bytes, so overflow is the common case, not an
/// edge case.
///
/// NOTE(review): this assumes the encrypted `FheUint8` evaluation is also
/// arithmetic modulo 2^8 — confirm against the `phantom_zone` documentation.
fn function1(a: u8, b: u8, c: u8, d: u8) -> u8 {
    a.wrapping_add(b).wrapping_mul(c).wrapping_mul(d)
}
/// Evaluates `((a + b) * c) * d` on `FheUint8` operands.
///
/// Mirrors the plaintext `function1`; each intermediate result is bound to a
/// local so the next operator can borrow it.
fn function1_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 {
    let sum = a + b;
    let scaled = &sum * c;
    &scaled * d
}
/// Plaintext reference for `function2_fhe`: computes `(a * b) + (c * d)`.
///
/// Both products and the final sum wrap modulo 2^8 instead of using the plain
/// `*` / `+` operators, which panic on `u8` overflow in debug builds. `main`
/// feeds this function uniformly random bytes, so overflow is the common case.
///
/// NOTE(review): this assumes the encrypted `FheUint8` evaluation is also
/// arithmetic modulo 2^8 — confirm against the `phantom_zone` documentation.
fn function2(a: u8, b: u8, c: u8, d: u8) -> u8 {
    a.wrapping_mul(b).wrapping_add(c.wrapping_mul(d))
}
/// Evaluates `(a * b) + (c * d)` on `FheUint8` operands.
///
/// Mirrors the plaintext `function2`; the two products are computed first
/// (left one, then right one) and then added.
fn function2_fhe(a: &FheUint8, b: &FheUint8, c: &FheUint8, d: &FheUint8) -> FheUint8 {
    let left = a * b;
    let right = c * d;
    &left + &right
}
/// Walks through the interactive multi-party protocol for 4 parties in a
/// single process: key generation, the two setup rounds, server-side
/// evaluation of `function1_fhe` and `function2_fhe` on encrypted inputs, and
/// multi-party decryption. Each decrypted result is checked against the
/// plaintext reference (`function1` / `function2`) and the program panics on
/// a mismatch.
fn main() {
    // Select parameter set
    set_parameter_set(ParameterSelector::InteractiveLTE4Party);

    // Set the application's common reference seed
    let mut seed = [0u8; 32];
    thread_rng().fill_bytes(&mut seed);
    set_common_reference_seed(seed);

    let no_of_parties = 4;

    // Client side //

    // Clients generate their private keys
    let cks = (0..no_of_parties)
        .map(|_| gen_client_key())
        .collect_vec();

    // -- Round 1 -- //
    // In round 1 each client generates their share for the collective public key.
    // They send public key shares to each other with or without the server.
    // After receiving others' public key shares, clients independently aggregate
    // the shares and produce the collective public key `pk`
    let pk_shares = cks.iter().map(|k| collective_pk_share(k)).collect_vec();

    // Clients aggregate public key shares to produce collective public key `pk`
    let pk = aggregate_public_key_shares(&pk_shares);

    // -- Round 2 -- //
    // In round 2 each client generates server key share using the public key `pk`.
    // Clients may also encrypt their private inputs using collective public key
    // `pk`. Each client then uploads their server key share and private input
    // ciphertexts to the server.

    // Clients generate server key shares
    //
    // We assign user_id 0 to client 0, user_id 1 to client 1, user_id 2 to client
    // 2, and user_id 3 to client 3 (the ids come from `enumerate()` below).
    //
    // Note that `user_id`'s must be unique among the clients and must be less than
    // total number of clients.
    let server_key_shares = cks
        .iter()
        .enumerate()
        .map(|(user_id, k)| collective_server_key_share(k, user_id, no_of_parties, &pk))
        .collect_vec();

    // Each client encrypts their private inputs using the collective public key
    // `pk`. Unlike the non-interactive MPC protocol, private inputs are
    // encrypted using the collective public key.
    //
    // A one-element array is enough to obtain the `&[u8]` batch; no heap
    // allocation is needed.
    let c0_a = thread_rng().gen::<u8>();
    let c0_enc = pk.encrypt([c0_a].as_slice());
    let c1_a = thread_rng().gen::<u8>();
    let c1_enc = pk.encrypt([c1_a].as_slice());
    let c2_a = thread_rng().gen::<u8>();
    let c2_enc = pk.encrypt([c2_a].as_slice());
    let c3_a = thread_rng().gen::<u8>();
    let c3_enc = pk.encrypt([c3_a].as_slice());

    // Clients upload their server key shares along with private encrypted inputs
    // to the server

    // Server side //

    // Server receives server key shares from each client and proceeds to
    // aggregate the shares and produce the server key
    let server_key = aggregate_server_key_shares(&server_key_shares);
    server_key.set_server_key();

    // Server proceeds to extract clients' private inputs
    //
    // Clients encrypt their FheUint8 inputs packed in a batched ciphertext.
    // The server must extract clients' private inputs from the batch ciphertext
    // either (1) using `extract_at(index)` to extract `index`^{th} FheUint8
    // ciphertext (2) or using `extract_all()` to extract all available FheUint8s
    // (3) or using `extract_many(many)` to extract first `many` available FheUint8s
    let c0_a_enc = c0_enc.extract_at(0);
    let c1_a_enc = c1_enc.extract_at(0);
    let c2_a_enc = c2_enc.extract_at(0);
    let c3_a_enc = c3_enc.extract_at(0);

    // Server proceeds to evaluate function1 on clients' private inputs
    let ct_out_f1 = function1_fhe(&c0_a_enc, &c1_a_enc, &c2_a_enc, &c3_a_enc);

    // After server has finished evaluating the circuit on client private
    // inputs, clients can proceed to multi-party decryption protocol to
    // decrypt output ciphertext

    // Client side //

    // In multi-party decryption protocol, each client must come online, download
    // the output ciphertext from the server, produce an "output ciphertext"
    // dependent decryption share, and send it to other parties. After receiving
    // decryption shares of other parties, clients independently aggregate the
    // decryption shares and decrypt the output ciphertext.

    // Clients generate decryption shares
    let decryption_shares = cks
        .iter()
        .map(|k| k.gen_decryption_share(&ct_out_f1))
        .collect_vec();

    // After receiving decryption shares from other parties, clients aggregate the
    // shares and decrypt output ciphertext
    let out_f1 = cks[0].aggregate_decryption_shares(&ct_out_f1, &decryption_shares);

    // Check correctness of function1 output; `assert_eq!` reports both values
    // if they differ.
    let want_f1 = function1(c0_a, c1_a, c2_a, c3_a);
    assert_eq!(out_f1, want_f1);

    // --------

    // Once server key is produced it can be re-used across different functions
    // with different private client inputs for the same set of clients.
    //
    // Here we run `function2_fhe` for the same set of clients but with different
    // private inputs. Clients do not need to participate in the 2 round
    // protocol again, instead they only upload their new private inputs to the
    // server.

    // Clients encrypt their private inputs
    let c0_a = thread_rng().gen::<u8>();
    let c0_enc = pk.encrypt([c0_a].as_slice());
    let c1_a = thread_rng().gen::<u8>();
    let c1_enc = pk.encrypt([c1_a].as_slice());
    let c2_a = thread_rng().gen::<u8>();
    let c2_enc = pk.encrypt([c2_a].as_slice());
    let c3_a = thread_rng().gen::<u8>();
    let c3_enc = pk.encrypt([c3_a].as_slice());

    // Clients upload only their new private inputs to the server

    // Server side //

    // Server receives private inputs from the clients, extracts them, and
    // proceeds to evaluate `function2_fhe`
    let c0_a_enc = c0_enc.extract_at(0);
    let c1_a_enc = c1_enc.extract_at(0);
    let c2_a_enc = c2_enc.extract_at(0);
    let c3_a_enc = c3_enc.extract_at(0);
    let ct_out_f2 = function2_fhe(&c0_a_enc, &c1_a_enc, &c2_a_enc, &c3_a_enc);

    // Client side //

    // Clients generate decryption shares for `ct_out_f2`
    let decryption_shares = cks
        .iter()
        .map(|k| k.gen_decryption_share(&ct_out_f2))
        .collect_vec();

    // Clients aggregate decryption shares and decrypt `ct_out_f2`
    let out_f2 = cks[0].aggregate_decryption_shares(&ct_out_f2, &decryption_shares);

    // We check correctness of function2
    let want_f2 = function2(c0_a, c1_a, c2_a, c3_a);
    assert_eq!(out_f2, want_f2);
}