· 5 years ago · Mar 05, 2020, 07:30 AM
1from typing import Tuple
2
3import numpy as np
4from seal import Plaintext, PublicKey, DoubleVector, Ciphertext, CKKSEncoder, SEALContext, Encryptor, \
5 Decryptor, Evaluator, SecretKey
6
7
class CipherMatrix:
    """A matrix of CKKS ciphertexts stored as one ``Ciphertext`` per row.

    ``rows`` is the number of stored ciphertexts; ``cols`` records the
    logical row width but is not enforced by this class.
    """

    def __init__(self, rows: int, cols: int):
        """Create a matrix of ``rows`` fresh, empty ciphertexts.

        :param rows: number of rows (one ciphertext per row).
        :param cols: logical number of columns encoded in each row.
        """
        self.rows = rows
        self.cols = cols
        # One distinct Ciphertext object per row. ``__getitem__`` hands out
        # these objects themselves, so callers can fill them in place.
        self.data = tuple(Ciphertext() for _ in range(rows))

    def __setitem__(self, key: int, value: 'Ciphertext'):
        """Replace the ciphertext stored for row ``key``.

        Negative indices are supported like on any sequence.

        :raises IndexError: if ``key`` is out of range.
        """
        # Go through a list rather than splicing tuple slices. Splicing
        # mishandles negative keys (key + 1 == 0 re-appends the whole
        # tuple) and silently grows the tuple for out-of-range keys.
        data = list(self.data)
        data[key] = value
        self.data = tuple(data)

    def __getitem__(self, item) -> 'Ciphertext':
        """Return the stored ciphertext for row ``item`` (the object itself, not a copy)."""
        return self.data[item]
19
20
class Matrix:
    """A plaintext matrix stored as a tuple of ``DoubleVector`` rows."""

    def __init__(self, rows: int, cols: int, data=None):
        """Create a ``rows x cols`` plaintext matrix.

        :param rows: number of rows.
        :param cols: number of columns.
        :param data: optional 2D array-like (e.g. a NumPy array or nested
            lists) with the initial values. When ``None`` or empty, every row
            is an empty ``DoubleVector``.
        :raises ValueError: if ``data`` does not contain ``rows`` rows of
            ``cols`` values each.
        """
        self.rows = rows
        self.cols = cols
        # Test for "no data" explicitly: ``if data:`` raises ValueError when
        # ``data`` is a NumPy array with more than one element.
        if data is None or len(data) == 0:
            self.data = tuple(DoubleVector() for _ in range(rows))
            return

        # Convert to plain Python floats so DoubleVector receives a flat
        # sequence of doubles for each row.
        values = [[float(v) for v in row] for row in data]
        if len(values) != rows:
            raise ValueError(f'data has {len(values)} rows, expected {rows}')
        for index, row in enumerate(values):
            if len(row) != cols:
                raise ValueError(f'row {index} of data has {len(row)} values, expected {cols}')
        self.data = tuple(DoubleVector(row) for row in values)

    @classmethod
    def from_numpy_array(cls, array: np.ndarray) -> 'Matrix':
        """Build a ``Matrix`` from a 2D array.

        :param array: 2D NumPy array (any array-like is converted with
            ``np.asarray``).
        :returns: the new ``Matrix`` holding the array's values.
        :raises ValueError: if ``array`` is not two-dimensional.
        """
        array = np.asarray(array)
        if array.ndim != 2:
            raise ValueError('array is not a 2D matrix')
        rows, cols = array.shape

        return cls(rows=rows, cols=cols, data=array)

    def __setitem__(self, key: int, value: list):
        """Replace row ``key`` (negative indices supported).

        :raises IndexError: if ``key`` is out of range.
        """
        # Assign through a list. Tuple-slice splicing mishandles negative
        # keys and silently grows the tuple for out-of-range keys.
        data = list(self.data)
        data[key] = value
        self.data = tuple(data)

    def __getitem__(self, item):
        """Return row ``item``."""
        return self.data[item]
45
46
class MatrixEvaluator:
    """Encrypt, decrypt and combine matrices with CKKS, one ciphertext per row.

    A ``rows x cols`` matrix is processed row by row. Each row is encoded into
    one ``Plaintext`` and encrypted into one ``Ciphertext`` of a
    ``CipherMatrix``.
    """

    def __init__(self, context: 'SEALContext', scale: int, public_key: 'PublicKey', secret_key: 'SecretKey'):
        """Create the SEAL helper objects used by every matrix operation.

        :param context: SEAL context shared by the encoder, evaluator,
            encryptor and decryptor.
        :param scale: CKKS scale passed to every ``encode`` call.
        :param public_key: key used by the encryptor.
        :param secret_key: key used by the decryptor.
        """
        self.scale = scale
        self.context = context
        self.encoder = CKKSEncoder(context)
        self.evaluator = Evaluator(context)
        self.encryptor = Encryptor(context, public_key)
        self.decryptor = Decryptor(context, secret_key)

    def encrypt(self, matrix: np.ndarray) -> 'CipherMatrix':
        """Encrypt a 2D matrix row by row.

        :param matrix: 2D array-like of real numbers; converted with
            ``np.asarray(..., dtype=float)``.
        :returns: a ``CipherMatrix`` holding one ciphertext per row.
        :raises ValueError: if ``matrix`` is not two-dimensional.
        """
        matrix = np.asarray(matrix, dtype=float)
        if matrix.ndim != 2:
            raise ValueError('matrix is not a 2D matrix')

        rows, cols = matrix.shape
        cipher_matrix = CipherMatrix(rows=rows, cols=cols)

        for i in range(rows):
            # matrix[i] is exactly one 1D row; a slice such as matrix[i:]
            # would be a 2D block of all remaining rows.
            row = DoubleVector(matrix[i].tolist())
            encoded_row = Plaintext()
            self.encoder.encode(row, self.scale, encoded_row)
            # Encrypt straight into the ciphertext object stored for row i.
            self.encryptor.encrypt(encoded_row, cipher_matrix[i])

        return cipher_matrix

    def decrypt(self, cipher_matrix: 'CipherMatrix') -> list:
        """Decrypt and decode every row of ``cipher_matrix``.

        :returns: a list with one decoded ``DoubleVector`` per row.
        """
        matrix = []

        for i in range(cipher_matrix.rows):
            encoded_row = Plaintext()
            self.decryptor.decrypt(cipher_matrix[i], encoded_row)
            row = DoubleVector()
            self.encoder.decode(encoded_row, row)
            # NOTE(review): CKKS decoding presumably fills every slot, so only
            # the first ``cipher_matrix.cols`` values are meaningful — confirm.
            matrix.append(row)

        return matrix

    def add(self, matrix_a: 'CipherMatrix', matrix_b: 'CipherMatrix') -> 'CipherMatrix':
        """Add two encrypted matrices element-wise.

        :returns: a new ``CipherMatrix`` with the row-wise sums.
        :raises ArithmeticError: if the matrices' dimensions differ.
        """
        self.validate_same_dimension(matrix_a, matrix_b)

        result_matrix = CipherMatrix(rows=matrix_a.rows, cols=matrix_a.cols)
        for i in range(matrix_a.rows):
            # The sum is written into the result ciphertext for row i.
            self.evaluator.add(matrix_a[i], matrix_b[i], result_matrix[i])

        return result_matrix

    def add_plain(self, matrix_a: 'CipherMatrix', matrix_b: 'Matrix') -> 'CipherMatrix':
        """Add a plaintext matrix to an encrypted matrix element-wise.

        :param matrix_a: encrypted left operand.
        :param matrix_b: plaintext right operand; each row is encoded with
            ``self.scale`` before the addition.
        :returns: a new ``CipherMatrix`` with the row-wise sums.
        :raises ArithmeticError: if the matrices' dimensions differ.
        """
        self.validate_same_dimension(matrix_a, matrix_b)

        result_matrix = CipherMatrix(rows=matrix_a.rows, cols=matrix_a.cols)

        for i in range(matrix_a.rows):
            encoded_row = Plaintext()
            # NOTE(review): encodes with self.scale; if matrix_a was rescaled its
            # scale may no longer match — confirm before relying on add_plain.
            self.encoder.encode(matrix_b[i], self.scale, encoded_row)
            # Move the plaintext to this row ciphertext's parms_id before adding.
            self.evaluator.mod_switch_to_inplace(encoded_row, matrix_a[i].parms_id())
            # Add the encoded, mod-switched plaintext, not the raw row values.
            self.evaluator.add_plain(matrix_a[i], encoded_row, result_matrix[i])

        return result_matrix

    def dot(self):
        """Multiply encrypted matrices (not implemented yet).

        :raises NotImplementedError: always. Previously this silently
            returned ``None``.
        """
        raise NotImplementedError('MatrixEvaluator.dot is not implemented yet')

    @staticmethod
    def validate_same_dimension(matrix_a, matrix_b):
        """Ensure both operands have the same ``rows`` and ``cols``.

        :raises ArithmeticError: if either dimension differs.
        """
        if matrix_a.rows != matrix_b.rows or matrix_a.cols != matrix_b.cols:
            raise ArithmeticError("Matrices aren't of the same dimension")
115
116
class RescalingEvaluator:
    """Bundle of SEAL CKKS helper objects built from one context and key pair.

    NOTE(review): only the constructor is visible here. The name suggests the
    class is meant for rescaling-related operations — confirm.
    """

    def __init__(self, context: SEALContext, scale: int, public_key: PublicKey, secret_key: SecretKey):
        """Store the encoding parameters and build the SEAL helpers.

        :param context: SEAL context every helper is constructed from.
        :param scale: CKKS encoding scale, stored on the instance.
        :param public_key: key handed to the ``Encryptor``.
        :param secret_key: key handed to the ``Decryptor``.
        """
        self.scale, self.context = scale, context
        ctx = self.context

        # All helpers share the same context: encoder and evaluator first,
        # then the key-bound encryptor and decryptor.
        self.encoder, self.evaluator = CKKSEncoder(ctx), Evaluator(ctx)
        self.encryptor = Encryptor(ctx, public_key)
        self.decryptor = Decryptor(ctx, secret_key)
124 self.decryptor = Decryptor(context, secret_key)