-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
161 lines (136 loc) · 5.07 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import numpy as np
import jax.random
import jax.numpy as jnp
import jax.scipy.optimize
jax.config.update("jax_enable_x64", True)
from functools import partial
from qcnn import variational_classifier
#########################
######### Loss ##########
#########################
@partial(jax.jit, static_argnames=["nqubits", "loss_type", "gate_id"])
def single_loss(
    weights: jnp.ndarray,
    bias: jnp.ndarray,
    ground_state: jnp.ndarray,
    label: int,
    nqubits: int,
    loss_type: str,
    gate_id: str
) -> float:
    """
    Computes the loss of a single ground state.

    Parameters:
        weights (jnp.ndarray): Weights of the QCNN.
        bias (jnp.ndarray): Bias of the QCNN.
        ground_state (jnp.ndarray): The ground state as a 1D array.
        label (int): The label of the ground state, used as an index into the
            classifier output.
        nqubits (int): The number of qubits in the system.
        loss_type (str): The type of loss to compute. Must be either
            "cross_entropy" or "mean_squares".
        gate_id (str): The identifier for the type of gate sequence to apply.

    Returns:
        float: The computed loss (a scalar JAX value).

    Raises:
        ValueError: If ``loss_type`` is not one of the supported values.
    """
    # loss_type is a static argument, so this check runs once at trace time.
    # Without it an unknown value would leave the result unbound and surface as
    # an opaque UnboundLocalError.
    if loss_type not in ("cross_entropy", "mean_squares"):
        raise ValueError(
            f"Unknown loss_type {loss_type!r}; "
            "expected 'cross_entropy' or 'mean_squares'."
        )
    proj = variational_classifier(weights, bias, ground_state, nqubits, gate_id)
    if loss_type == "cross_entropy":
        # Negative base-2 log of the classifier output at the true label.
        return -jnp.log2(proj[label])
    # "mean_squares": for a real 1-D `proj` this equals ||proj - e_label||^2,
    # where e_label is the one-hot vector of `label`
    # (since ||e_label||^2 = 1 and <proj, e_label> = proj[label]).
    return 1 + jnp.linalg.norm(proj) ** 2 - 2 * proj[label]
@partial(jax.jit, static_argnames=["nqubits", "with_bias", "loss_type", "gate_id"])
def loss(
    weights_and_bias: tuple[jnp.ndarray, jnp.ndarray] | jnp.ndarray,
    ground_states: jnp.ndarray,
    labels: jnp.ndarray,
    nqubits: int,
    with_bias: bool,
    loss_type: str,
    gate_id: str
) -> float:
    """
    Average the per-state QCNN loss over a batch of ground states.

    Parameters:
        weights_and_bias (tuple[jnp.ndarray, jnp.ndarray] | jnp.ndarray):
            A ``(weights, bias)`` pair when ``with_bias`` is True, otherwise
            the weights array alone.
        ground_states (jnp.ndarray): Ground states stacked along axis 0.
        labels (jnp.ndarray): One label per ground state, aligned on axis 0.
        nqubits (int): The number of qubits in the system.
        with_bias (bool): Whether ``weights_and_bias`` carries a bias term.
        loss_type (str): "cross_entropy" or "mean_squares".
        gate_id (str): The identifier for the type of gate sequence to apply.

    Returns:
        float: Mean of the per-state losses.
    """
    if not with_bias:
        # No trainable bias: use a fixed length-4 zero vector.
        # NOTE(review): assumes 4 matches the bias shape variational_classifier
        # expects; confirm.
        weights, bias = weights_and_bias, jnp.zeros(4)
    else:
        weights, bias = weights_and_bias

    # Map over axis 0 of the states and labels; every other argument is shared.
    batched_single_loss = jax.vmap(
        single_loss,
        in_axes=[None, None, 0, 0, None, None, None],
    )
    per_state_costs = batched_single_loss(
        weights, bias, ground_states, labels, nqubits, loss_type, gate_id
    )
    return jnp.mean(per_state_costs)
#############################
######### Accuracy ##########
#############################
@partial(jax.jit, static_argnames=["nqubits", "gate_id"])
def single_pred(
    weights: jnp.ndarray,
    bias: jnp.ndarray,
    ground_state: jnp.ndarray,
    nqubits: int,
    gate_id: str,
) -> int:
    """
    Computes the prediction of the QCNN for a single ground state.

    Parameters:
        weights (jnp.ndarray): The weights of the QCNN.
        bias (jnp.ndarray): The bias of the QCNN.
        ground_state (jnp.ndarray): The ground state as a 1D array.
        nqubits (int): The number of qubits in the system.
        gate_id (str): The identifier for the type of gate sequence to apply.

    Returns:
        int: Index of the largest classifier output, returned as a 0-d integer
        JAX array.
    """
    projectors = variational_classifier(weights, bias, ground_state, nqubits, gate_id)
    # Under jit/vmap `projectors` is a tracer, so use the JAX implementation.
    # np.argmax only works on tracers by delegating to their .argmax method,
    # and any NumPy path that tries to convert a tracer to a concrete array
    # raises an error.
    return jnp.argmax(projectors)
@partial(jax.jit, static_argnames=["nqubits", "gate_id"])
def pred(
    weights: jnp.ndarray,
    bias: jnp.ndarray,
    ground_states: jnp.ndarray,
    nqubits: int,
    gate_id: str
) -> jnp.ndarray:
    """
    Predict a label for every ground state in a batch.

    Parameters:
        weights (jnp.ndarray): The weights of the QCNN.
        bias (jnp.ndarray): The bias of the QCNN.
        ground_states (jnp.ndarray): Ground states stacked along axis 0.
        nqubits (int): The number of qubits in the system.
        gate_id (str): The identifier for the type of gate sequence to apply.

    Returns:
        jnp.ndarray: One predicted label per ground state.
    """
    # Only the ground-state batch (third argument) is mapped; weights, bias and
    # the static configuration are shared across the batch.
    state_axes = [None, None, 0, None, None]
    return jax.vmap(single_pred, in_axes=state_axes)(
        weights, bias, ground_states, nqubits, gate_id
    )
@jax.jit
def acc(
    predictions: jnp.ndarray,
    labels: jnp.ndarray
) -> float:
    """
    Computes the accuracy of the predictions with respect to the labels.

    Parameters:
        predictions (jnp.ndarray): An array of predicted labels.
        labels (jnp.ndarray): An array of true labels, same shape as
            ``predictions``.

    Returns:
        float: The percentage (0-100) of positions where prediction equals label.

    Raises:
        ValueError: If ``predictions`` and ``labels`` have different shapes.
    """
    pred_shape, label_shape = jnp.shape(predictions), jnp.shape(labels)
    # Shapes are static under jit, so this check runs at trace time. Without it,
    # e.g. (n,) vs (n, 1) would silently broadcast to an (n, n) comparison, and
    # the count below (normalised by n) could even exceed 100%.
    if pred_shape != label_shape:
        raise ValueError(
            "predictions and labels must have the same shape, "
            f"got {pred_shape} and {label_shape}"
        )
    return (predictions == labels).sum() * 100 / len(labels)