FusionModel.py
import pennylane as qml
import torch
import torch.nn as nn
from math import pi

from Arguments import Arguments

args = Arguments()
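
# Arguments is the project's own config class; from its usage below it is assumed
# to expose at least n_qubits, device, and the per-modality RNN sizes
# a_insize/a_hidsize, v_insize/v_hidsize, t_insize/t_hidsize.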

def translator(net):
    assert isinstance(net, list)
    updated_design = {}
    q = net[0:7]
    c = net[7:13]
    p = net[13:]

    # number of layer repetitions
    updated_design['layer_repe'] = 5

    # categories of single-qubit parametric gates
    for i in range(args.n_qubits):
        if q[i] == 0:
            category = 'Rx'
        else:
            category = 'Ry'
        updated_design['rot' + str(i)] = category

    # categories and positions of entangling gates
    pos_dict = {'00': 3, '01': 4, '10': 5, '11': 6}
    for j in range(args.n_qubits-1):
        if c[j] == 0:
            category = 'IsingXX'
        else:
            category = 'IsingZZ'
        if j <= 2:
            position = pos_dict[str(p[2*j]) + str(p[2*j+1])]
        else:
            position = j + 1
        updated_design['enta' + str(j)] = (category, [j, position])

    updated_design['total_gates'] = len(q) + len(c)
    return updated_design
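
# Illustrative usage sketch (not part of the original file). translator expects a
# flat encoding of 7 rotation bits, 6 entangler bits, and 6 position bits; the
# sample values below are hypothetical.
#
#   net = [0, 1, 0, 1, 0, 1, 0,      # q: Rx (0) / Ry (1) choice per qubit
#          1, 0, 1, 0, 1, 0,         # c: IsingXX (0) / IsingZZ (1) per entangler
#          0, 0, 0, 1, 1, 0]         # p: 2-bit target positions for j = 0, 1, 2
#   design = translator(net)
#   # design['rot0'] == 'Rx'; design['enta0'] == ('IsingZZ', [0, 3])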

# use PennyLane's legacy return system, so the qnode returns one stacked tensor
qml.disable_return()
dev = qml.device("lightning.qubit", wires=args.n_qubits)

@qml.qnode(dev, interface="torch", diff_method="adjoint")
def quantum_net(q_input_features_flat, q_weights_rot_flat, q_weights_enta_flat, **kwargs):
    current_design = kwargs['design']
    q_input_features = q_input_features_flat.reshape(args.n_qubits, 3)
    q_weights_rot = q_weights_rot_flat.reshape(current_design['layer_repe'], args.n_qubits)
    q_weights_enta = q_weights_enta_flat.reshape(current_design['layer_repe'], args.n_qubits-1)
    for layer in range(current_design['layer_repe']):
        # data re-uploading
        for i in range(args.n_qubits):
            qml.Rot(*q_input_features[i], wires=i)
        # single-qubit parametric gates and entangling gates
        for j in range(args.n_qubits-1):
            if current_design['rot' + str(j)] == 'Rx':
                qml.RX(q_weights_rot[layer][j], wires=j)
            else:
                qml.RY(q_weights_rot[layer][j], wires=j)
            if current_design['enta' + str(j)][0] == 'IsingXX':
                qml.IsingXX(q_weights_enta[layer][j], wires=current_design['enta' + str(j)][1])
            else:
                qml.IsingZZ(q_weights_enta[layer][j], wires=current_design['enta' + str(j)][1])
        # rotation on the last qubit, which has no entangler of its own
        if current_design['rot' + str(args.n_qubits-1)] == 'Rx':
            qml.RX(q_weights_rot[layer][-1], wires=args.n_qubits-1)
        else:
            qml.RY(q_weights_rot[layer][-1], wires=args.n_qubits-1)
    return [qml.expval(qml.PauliZ(i)) for i in range(args.n_qubits)]
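
# Illustrative sketch (not part of the original file): the compiled circuit can be
# inspected with qml.draw, given a design dict from translator and flat tensors of
# the shapes unpacked above.
#
#   design = translator(net)   # net as in the hypothetical sketch after translator
#   feats = torch.zeros(args.n_qubits * 3)
#   rot = torch.zeros(design['layer_repe'] * args.n_qubits)
#   enta = torch.zeros(design['layer_repe'] * (args.n_qubits - 1))
#   print(qml.draw(quantum_net)(feats, rot, enta, design=design))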

class QuantumLayer(nn.Module):
    def __init__(self, arguments, design):
        super(QuantumLayer, self).__init__()
        self.args = arguments
        self.design = design
        # trainable gate angles, initialised uniformly in [0, pi)
        self.q_params_rot = nn.Parameter(pi * torch.rand(self.design['layer_repe'] * self.args.n_qubits))
        self.q_params_enta = nn.Parameter(pi * torch.rand(self.design['layer_repe'] * (self.args.n_qubits-1)))

    def forward(self, input_features):
        q_out = torch.Tensor(0, self.args.n_qubits)
        q_out = q_out.to(self.args.device)
        for elem in input_features:
            q_out_elem = quantum_net(elem, self.q_params_rot, self.q_params_enta, design=self.design).float().unsqueeze(0)
            q_out = torch.cat((q_out, q_out_elem))
        return q_out
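
# Illustrative usage sketch (not part of the original file). The layer maps a batch
# of flattened angle vectors to per-qubit Pauli-Z expectations; samples are pushed
# through the qnode one at a time because quantum_net is not batched.
#
#   layer = QuantumLayer(args, translator(net))   # net as sketched above (hypothetical)
#   x = pi * torch.rand(4, args.n_qubits * 3)     # batch of 4 angle vectors
#   layer(x).shape                                # -> torch.Size([4, n_qubits])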

class QNet(nn.Module):
    def __init__(self, arguments, design):
        super(QNet, self).__init__()
        self.args = arguments
        self.design = design
        # one RNN encoder and one projection head per modality (audio, video, text)
        self.ClassicalLayer_a = nn.RNN(self.args.a_insize, self.args.a_hidsize)
        self.ClassicalLayer_v = nn.RNN(self.args.v_insize, self.args.v_hidsize)
        self.ClassicalLayer_t = nn.RNN(self.args.t_insize, self.args.t_hidsize)
        self.ProjLayer_a = nn.Linear(self.args.a_hidsize, self.args.a_hidsize)
        self.ProjLayer_v = nn.Linear(self.args.v_hidsize, self.args.v_hidsize)
        self.ProjLayer_t = nn.Linear(self.args.t_hidsize, self.args.t_hidsize)
        self.QuantumLayer = QuantumLayer(self.args, self.design)
        self.Regressor = nn.Linear(self.args.n_qubits, 1)
        # freeze every classical parameter (including the regressor);
        # only the quantum gate angles are trained
        for name, param in self.named_parameters():
            if "QuantumLayer" not in name:
                param.requires_grad = False

    def forward(self, x_a, x_v, x_t):
        # nn.RNN expects (seq_len, batch, features), so move batch to dim 1
        x_a = torch.permute(x_a, (1, 0, 2))
        x_v = torch.permute(x_v, (1, 0, 2))
        x_t = torch.permute(x_t, (1, 0, 2))
        # keep only the last hidden state of each sequence
        a_h = self.ClassicalLayer_a(x_a)[0][-1]
        v_h = self.ClassicalLayer_v(x_v)[0][-1]
        t_h = self.ClassicalLayer_t(x_t)[0][-1]
        a_o = torch.relu(self.ProjLayer_a(a_h))
        v_o = torch.sigmoid(self.ProjLayer_v(v_h)) * pi
        t_o = torch.sigmoid(self.ProjLayer_t(t_h)) * pi
        x_p = torch.cat((a_o, v_o, t_o), 1)
        exp_val = self.QuantumLayer(x_p)
        # scale tanh output to [-3, 3]
        output = torch.tanh(self.Regressor(exp_val).squeeze(dim=1)) * 3
        return output
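
# Illustrative smoke test (not part of the original file). The encoding and shapes
# below are hypothetical; real values come from Arguments, the concatenated hidden
# sizes (a_hidsize + v_hidsize + t_hidsize) must equal n_qubits * 3 for the reshape
# inside quantum_net to succeed, and args.device is assumed to be 'cpu'.
if __name__ == "__main__":
    net = [0] * 19                              # hypothetical flat architecture encoding
    model = QNet(args, translator(net))
    x_a = torch.rand(2, 10, args.a_insize)      # (batch, seq_len, features) per modality
    x_v = torch.rand(2, 10, args.v_insize)
    x_t = torch.rand(2, 10, args.t_insize)
    print(model(x_a, x_v, x_t).shape)           # expected: torch.Size([2])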