// Originally posted Dec 24, 2020, 12:44 PM.
#include "examples.h"
#include <cmath>
#include <iostream>
#include <stdexcept>
#include <vector>

using namespace std;
using namespace seal;
6
7void get_ciphertext(shared_ptr<SEALContext> context, Ciphertext &input)
8{
9 auto &context_data = *context->get_context_data(input.parms_id());
10 auto &parms = context_data.parms();
11 auto &coeff_modulus = parms.coeff_modulus();
12 double scale = pow(2.0, 30);
13 print_parameters(context);
14 cout << endl;
15
16 KeyGenerator keygen(context);
17 auto public_key = keygen.public_key();
18 auto secret_key = keygen.secret_key();
19 auto relin_keys = keygen.relin_keys_local();
20 Encryptor encryptor(context, public_key);
21 Evaluator evaluator(context);
22 Decryptor decryptor(context, secret_key);
23 CKKSEncoder encoder(context);
24
25 cout << input.data() << endl;
26 Plaintext plain;
27 vector<double> vec;
28 decryptor.decrypt(input,plain);
29 encoder.decode(plain, vec);
30 cout << " The input obtained " << endl; // the decrypted output is different from the decrypted result in main program.
31 print_vector(vec, 3, 8);
32}
33
34int main()
35{
36 EncryptionParameters parms(scheme_type::CKKS);
37 size_t poly_modulus_degree = 32768;
38 parms.set_poly_modulus_degree(poly_modulus_degree);
39 parms.set_coeff_modulus(
40 CoeffModulus::Create(poly_modulus_degree, { 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
41 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30 }));
42
43 double scale = pow(2.0, 30);
44 auto context = SEALContext::Create(parms);
45 print_parameters(context);
46 cout << endl;
47 KeyGenerator keygen(context);
48 auto public_key = keygen.public_key();
49 auto secret_key = keygen.secret_key();
50 auto relin_keys = keygen.relin_keys_local();
51 Encryptor encryptor(context, public_key);
52 Evaluator evaluator(context);
53 Decryptor decryptor(context, secret_key);
54 CKKSEncoder encoder(context);
55 size_t slot_count = encoder.slot_count();
56 cout << "Number of slots: " << slot_count << endl;
57
58 double input = 0.769366402;
59 double weight = -0.406239561;
60 double bias = -9.158888066;
61
62 Plaintext input_plain, weight_plain, bias_plain;
63 Ciphertext input_cipher, weight_cipher, bias_cipher, node_output, temp1;
64
65 encoder.encode(input, scale, input_plain);
66 encoder.encode(weight, scale, weight_plain);
67 encryptor.encrypt(input_plain, input_cipher);
68 encryptor.encrypt(weight_plain, weight_cipher);
69 evaluator.multiply(input_cipher, weight_cipher, temp1);
70 evaluator.relinearize_inplace(temp1, relin_keys);
71 evaluator.rescale_to_next_inplace(temp1);
72
73 encoder.encode(bias, temp1.scale(), bias_plain);
74 encryptor.encrypt(bias_plain, bias_cipher);
75 evaluator.mod_switch_to_inplace(bias_cipher, temp1.parms_id());
76 cout << " bias scale after mod switch " << log2(bias_cipher.scale()) << endl;
77 evaluator.add(temp1, bias_cipher, node_output);
78
79 cout << node_output.data();
80 cout << endl;
81
82 Plaintext pl;
83 vector<double> vect;
84 decryptor.decrypt(node_output, pl);
85 encoder.decode(pl, vect);
86 cout << " The input sent to the function " << endl;
87 print_vector(vect, 3, 8);
88 cout << endl;
89 get_ciphertext(context, node_output);
90
91 return 0;
92}