Skip to content

Commit

Permalink
Fix param passthrough (#131)
Browse files Browse the repository at this point in the history
* fixes Gaussian state trainability

* dict instead of list for internal parameters

* tests parameter passthrough

* better handling of variable/constant params

* blacked

* modified changelog

* added two docstrings to make codefactor happier

* blacked

* Update .github/CHANGELOG.md

Co-authored-by: Sebastián Duque Mesa <[email protected]>

Co-authored-by: Sebastián Duque Mesa <[email protected]>
  • Loading branch information
ziofil and sduquemesa authored May 10, 2022
1 parent 30f8852 commit 8902144
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 31 deletions.
12 changes: 12 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@
* States in Gaussian and Fock representation now can be concatenated.
[(#130)](https://github.com/XanaduAI/MrMustard/pull/130)

* Parameter passthrough allows to use custom parameters in the model, that is, objects accept correlated parameters. For example,
```python
from mrmustard.lab.gates import Sgate, BSgate

BS = BSgate(theta=np.pi/4, theta_trainable=True)[0,1]
S0 = Sgate(r=BS.theta)[0]
S1 = Sgate(r=-BS.theta)[1]

circ = S0 >> S1 >> BS
```
[(#131)](https://github.com/XanaduAI/MrMustard/pull/131)

```python
from mrmustard.lab.states import Gaussian, Fock'
from mrmustard.lab.gates import Attenuator
Expand Down
1 change: 1 addition & 0 deletions mrmustard/lab/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def __init__(
eigenvalues_trainable=eigenvalues_trainable,
symplectic_trainable=symplectic_trainable,
eigenvalues_bounds=eigenvalues_bounds,
symplectic_bounds=(None, None),
modes=modes,
normalize=normalize,
)
Expand Down
65 changes: 34 additions & 31 deletions mrmustard/utils/parametrized.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,36 @@ class Parametrized:
"""

def __init__(self, **kwargs): # NOTE: only kwargs so that we can use the arg names
self._trainable_parameters = []
self._constant_parameters = []
self._trainable_parameters = {"symplectic": [], "euclidean": [], "orthogonal": []}
self._constant_parameters = {"symplectic": [], "euclidean": [], "orthogonal": []}
self._param_names = []
owner = f"{self.__class__.__qualname__}"
for name, value in kwargs.items():
if math.from_backend(value):
if math.is_trainable(value):
self._trainable_parameters.append(value)
elif name + "_trainable" in kwargs and kwargs[name + "_trainable"]:
trainable = (name + "_trainable" in kwargs and kwargs[name + "_trainable"] is True) or (
math.from_backend(value) and math.is_trainable(value)
)
constant = (
name + "_trainable" in kwargs
and kwargs[name + "_trainable"] is False
or (math.from_backend(value) and not math.is_trainable(value))
)
param_type = (
"symplectic"
if name.startswith("symplectic")
else "orthogonal"
if name.startswith("orthogonal")
else "euclidean"
)
if trainable:
if not math.from_backend(value) or (
math.from_backend(value) and not math.is_trainable(value)
):
value = math.new_variable(value, kwargs[name + "_bounds"], owner + ":" + name)
self._trainable_parameters.append(value)
else:
self._constant_parameters.append(value)
elif name + "_trainable" in kwargs and kwargs[name + "_trainable"]:
value = math.new_variable(value, kwargs[name + "_bounds"], owner + ":" + name)
self._trainable_parameters.append(value)
elif name + "_trainable" in kwargs and not kwargs[name + "_trainable"]:
value = math.new_constant(value, owner + ":" + name)
self._constant_parameters.append(value)
self._trainable_parameters[param_type].append(value)
elif constant:
if not math.from_backend(value):
value = math.new_constant(value, owner + ":" + name)
self._constant_parameters[param_type].append(value)
else:
name = "_" + name
self.__dict__[name] = value
Expand All @@ -61,40 +72,32 @@ def trainable_parameters(self) -> Dict[str, List[Trainable]]:
if hasattr(self, "_ops"):
return {
"symplectic": math.unique_tensors(
[p for item in self._ops for p in item.trainable_parameters["symplectic"]]
[p for op in self._ops for p in op.trainable_parameters["symplectic"]]
),
"orthogonal": math.unique_tensors(
[p for item in self._ops for p in item.trainable_parameters["orthogonal"]]
[p for op in self._ops for p in op.trainable_parameters["orthogonal"]]
),
"euclidean": math.unique_tensors(
[p for item in self._ops for p in item.trainable_parameters["euclidean"]]
[p for op in self._ops for p in op.trainable_parameters["euclidean"]]
),
}

return {
"symplectic": [],
"orthogonal": [],
"euclidean": self._trainable_parameters,
} # default
return self._trainable_parameters # default

@property
def constant_parameters(self) -> Dict[str, List[Tensor]]:
r"""Returns the dictionary of constant parameters, searching recursively in the object tree (for example, when in a Circuit)."""
if hasattr(self, "_ops"):
return {
"symplectic": math.unique_tensors(
[p for item in self._ops for p in item.constant_parameters["symplectic"]]
[p for op in self._ops for p in op.constant_parameters["symplectic"]]
),
"orthogonal": math.unique_tensors(
[p for item in self._ops for p in item.constant_parameters["orthogonal"]]
[p for op in self._ops for p in op.constant_parameters["orthogonal"]]
),
"euclidean": math.unique_tensors(
[p for item in self._ops for p in item.constant_parameters["euclidean"]]
[p for op in self._ops for p in op.constant_parameters["euclidean"]]
),
}

return {
"symplectic": [],
"orthogonal": [],
"euclidean": self._constant_parameters,
} # default
return self._constant_parameters # default
2 changes: 2 additions & 0 deletions tests/test_lab/test_detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def test_homodyne_on_2mode_squeezed_vacuum(s, X):

@given(s=st.floats(1.0, 10.0), X=st.floats(-5.0, 5.0), angle=st.floats(0, np.pi))
def test_homodyne_on_2mode_squeezed_vacuum_with_angle(s, X, angle):
r"""Check that homodyne detection on TMSV works with an arbitrary quadrature angle"""
homodyne = Homodyne(quadrature_angle=angle, result=X)
r = homodyne.r
remaining_state = TMSV(r=np.arcsinh(np.sqrt(abs(s)))) << homodyne[0]
Expand Down Expand Up @@ -243,6 +244,7 @@ def test_homodyne_on_2mode_squeezed_vacuum_with_displacement(s, X, d):
def test_heterodyne_on_2mode_squeezed_vacuum_with_displacement(
s, x, y, d
): # TODO: check if this is correct
r"""Check that heterodyne detection on TMSV works with an arbitrary displacement"""
tmsv = TMSV(r=np.arcsinh(np.sqrt(s))) >> Dgate(x=d[:2], y=d[2:])
heterodyne = Heterodyne(modes=[0], x=x, y=y)
remaining_state = tmsv << heterodyne[0]
Expand Down
28 changes: 28 additions & 0 deletions tests/test_utils/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,15 @@
from mrmustard.lab.gates import Sgate, BSgate, S2gate, Ggate, Interferometer, Ggate
from mrmustard.lab.circuit import Circuit
from mrmustard.utils.training import Optimizer
from mrmustard.utils.parametrized import Parametrized
from mrmustard.lab.states import Vacuum
from mrmustard.physics.gaussian import trace, von_neumann_entropy
from mrmustard import settings

from mrmustard.math import Math

math = Math()


@given(n=st.integers(0, 3))
def test_S2gate_coincidence_prob(n):
Expand Down Expand Up @@ -230,6 +235,29 @@ def cost_fn():
assert np.allclose(np.sinh(circ.trainable_parameters["euclidean"][2]) ** 2, 1, atol=1e-2)


def test_parameter_passthrough():
"""Same as the test above, but with param passthrough"""
tf.random.set_seed(137)
r = np.arcsinh(1.0)
par = Parametrized(
r=math.new_variable(r, (0.0, None), "r"),
phi=math.new_variable(np.random.normal(), (None, None), "phi"),
)
ops = [
S2gate(r=r, phi=0.0, phi_trainable=True)[0, 1],
S2gate(r=r, phi=0.0, phi_trainable=True)[2, 3],
S2gate(r=par.r, phi=par.phi)[1, 2],
]
circ = Circuit(ops)

def cost_fn():
return tf.abs((Vacuum(4) >> circ).ket(cutoffs=[2, 2, 2, 2])[1, 1, 1, 1]) ** 2

opt = Optimizer(euclidean_lr=0.001)
opt.minimize(cost_fn, by_optimizing=[par], max_steps=300)
assert np.allclose(np.sinh(circ.trainable_parameters["euclidean"][2]) ** 2, 1, atol=1e-2)


def test_making_thermal_state_as_one_half_two_mode_squeezed_vacuum():
"""Optimizes a Ggate on two modes so as to prepare a state with the same entropy
and mean photon number as a thermal state"""
Expand Down

0 comments on commit 8902144

Please sign in to comment.