-
Notifications
You must be signed in to change notification settings - Fork 3
/
4_ckks_basics.jl
155 lines (126 loc) · 5.56 KB
/
4_ckks_basics.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
include("utilities.jl")
using SEAL
using Printf
"""
    example_ckks_basics()

Evaluate the polynomial `PI*x^3 + 0.4*x + 1` homomorphically with the CKKS scheme,
demonstrating encoding, encryption/decryption, multiplication, relinearization,
rescaling, modulus switching and addition along the way.

This routine is based on the file `native/examples/4_ckks_basics.cpp` of the original
SEAL library and prints the same sequence of messages. Two differences are expected:
the line numbers passed to `print_line(@__LINE__)` refer to this Julia file, and the
decrypted values only approximate the exact ones because CKKS is an approximate scheme.

* [SEAL](https://github.com/microsoft/SEAL)
* [native/examples/4_ckks_basics.cpp](https://github.com/microsoft/SEAL/blob/master/native/examples/4_ckks_basics.cpp)

Returns `nothing`; all results are written to standard output.
"""
function example_ckks_basics()
    print_example_banner("Example: CKKS Basics")

    # Encryption parameters: degree-8192 polynomial modulus and a coefficient modulus of
    # 60-, 40-, 40- and 60-bit primes. Every plaintext below is encoded at
    # `initial_scale` = 2^40, the same bit size as the two middle primes.
    parms = EncryptionParameters(SchemeType.ckks)
    poly_modulus_degree = 8192
    set_poly_modulus_degree!(parms, poly_modulus_degree)
    set_coeff_modulus!(parms, coeff_modulus_create(poly_modulus_degree, [60, 40, 40, 60]))
    initial_scale = 2.0^40

    context = SEALContext(parms)
    print_parameters(context)
    println()

    # Key material and the objects that consume it.
    keygen = KeyGenerator(context)
    public_key_ = PublicKey()
    create_public_key!(public_key_, keygen)
    secret_key_ = secret_key(keygen)
    relin_keys_ = RelinKeys()
    create_relin_keys!(relin_keys_, keygen)
    encryptor = Encryptor(context, public_key_)
    evaluator = Evaluator(context)
    decryptor = Decryptor(context, secret_key_)
    encoder = CKKSEncoder(context)

    # One input value per slot, evenly spaced over [0, 1].
    slot_count_ = slot_count(encoder)
    println("Number of slots: ", slot_count_)
    input = collect(range(0.0, 1.0, length=slot_count_))
    println("Input vector:")
    print_vector(input, 3, 7)

    # Polynomial coefficients, defined once so that the encoded plaintexts and the
    # plaintext reference computation at the end cannot drift apart.
    coeff3 = 3.14159265
    coeff1 = 0.4
    coeff0 = 1.0

    println("Evaluating polynomial PI*x^3 + 0.4x + 1 ...")
    plain_coeff3 = Plaintext()
    plain_coeff1 = Plaintext()
    plain_coeff0 = Plaintext()
    encode!(plain_coeff3, coeff3, initial_scale, encoder)
    encode!(plain_coeff1, coeff1, initial_scale, encoder)
    encode!(plain_coeff0, coeff0, initial_scale, encoder)

    x_plain = Plaintext()
    print_line(@__LINE__)
    println("Encode input vectors.")
    encode!(x_plain, input, initial_scale, encoder)
    x1_encrypted = Ciphertext()
    encrypt!(x1_encrypted, x_plain, encryptor)

    # x^3 is formed as x^2 * (PI*x): each factor is relinearized/rescaled on its own
    # before the final multiplication, which is then rescaled once more.
    x3_encrypted = Ciphertext()
    print_line(@__LINE__)
    println("Compute x^2 and relinearize:")
    square!(x3_encrypted, x1_encrypted, evaluator)
    relinearize_inplace!(x3_encrypted, relin_keys_, evaluator)
    println(" + Scale of x^2 before rescale: ", log2(scale(x3_encrypted)), " bits")

    print_line(@__LINE__)
    println("Rescale x^2.")
    rescale_to_next_inplace!(x3_encrypted, evaluator)
    println(" + Scale of x^2 after rescale: ", log2(scale(x3_encrypted)), " bits")

    print_line(@__LINE__)
    println("Compute and rescale PI*x.")
    x1_encrypted_coeff3 = Ciphertext()
    multiply_plain!(x1_encrypted_coeff3, x1_encrypted, plain_coeff3, evaluator)
    println(" + Scale of PI*x before rescale: ", log2(scale(x1_encrypted_coeff3)), " bits")
    rescale_to_next_inplace!(x1_encrypted_coeff3, evaluator)
    println(" + Scale of PI*x after rescale: ", log2(scale(x1_encrypted_coeff3)), " bits")

    print_line(@__LINE__)
    println("Compute, relinearize, and rescale (PI*x)*x^2.")
    multiply_inplace!(x3_encrypted, x1_encrypted_coeff3, evaluator)
    relinearize_inplace!(x3_encrypted, relin_keys_, evaluator)
    println(" + Scale of PI*x^3 before rescale: ", log2(scale(x3_encrypted)), " bits")
    rescale_to_next_inplace!(x3_encrypted, evaluator)
    println(" + Scale of PI*x^3 after rescale: ", log2(scale(x3_encrypted)), " bits")

    print_line(@__LINE__)
    println("Compute and rescale 0.4*x.")
    multiply_plain_inplace!(x1_encrypted, plain_coeff1, evaluator)
    println(" + Scale of 0.4*x before rescale: ", log2(scale(x1_encrypted)), " bits")
    rescale_to_next_inplace!(x1_encrypted, evaluator)
    println(" + Scale of 0.4*x after rescale: ", log2(scale(x1_encrypted)), " bits")
    println()

    # The three terms have been rescaled a different number of times (x3: twice,
    # x1: once, plain_coeff0: never), so their chain indices and exact scales differ.
    print_line(@__LINE__)
    println("Parameters used by all three terms are different.")
    ci_x3 = chain_index(get_context_data(context, parms_id(x3_encrypted)))
    println(" + Modulus chain index for x3_encrypted: ", ci_x3)
    ci_x1 = chain_index(get_context_data(context, parms_id(x1_encrypted)))
    println(" + Modulus chain index for x1_encrypted: ", ci_x1)
    ci_c0 = chain_index(get_context_data(context, parms_id(plain_coeff0)))
    println(" + Modulus chain index for plain_coeff0: ", ci_c0)
    println()

    print_line(@__LINE__)
    println("The exact scales of all three terms are different:")
    @printf(" + Exact scale in PI*x^3: %.10f\n", scale(x3_encrypted))
    @printf(" + Exact scale in 0.4*x: %.10f\n", scale(x1_encrypted))
    @printf(" + Exact scale in 1: %.10f\n", scale(plain_coeff0))
    println()

    # Give both ciphertexts the same nominal scale that plain_coeff0 was encoded with.
    # NOTE(review): `scale!` presumably only overwrites the recorded scale (like
    # `x.scale() = ...` in the C++ example); the small mismatch is accepted as error.
    print_line(@__LINE__)
    println("Normalize scales to 2^40.")
    scale!(x3_encrypted, initial_scale)
    scale!(x1_encrypted, initial_scale)

    # x3_encrypted was rescaled the most, so its parameters are the lowest level;
    # switch the other two terms down to it so they can be added.
    print_line(@__LINE__)
    println("Normalize encryption parameters to the lowest level.")
    last_parms_id = parms_id(x3_encrypted)
    mod_switch_to_inplace!(x1_encrypted, last_parms_id, evaluator)
    mod_switch_to_inplace!(plain_coeff0, last_parms_id, evaluator)

    print_line(@__LINE__)
    println("Compute PI*x^3 + 0.4*x + 1.")
    encrypted_result = Ciphertext()
    add!(encrypted_result, x3_encrypted, x1_encrypted, evaluator)
    add_plain_inplace!(encrypted_result, plain_coeff0, evaluator)

    plain_result = Plaintext()
    print_line(@__LINE__)
    println("Decrypt and decode PI*x^3 + 0.4x + 1.")
    println(" + Expected result:")
    # Unencrypted reference, evaluated as (coeff3*x*x + coeff1)*x + coeff0 element-wise.
    true_result = @. (coeff3 * input * input + coeff1) * input + coeff0
    print_vector(true_result, 3, 7)

    decrypt!(plain_result, encrypted_result, decryptor)
    result = similar(input)
    decode!(result, plain_result, encoder)
    println(" + Computed result ...... Correct.")
    print_vector(result, 3, 7)

    return
end