Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Channel from XY #514

Merged
merged 9 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion mrmustard/lab_dev/transformations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
from mrmustard.physics.ansatz import PolyExpAnsatz, ArrayAnsatz
from mrmustard.physics.representations import Representation
from mrmustard.physics.wires import Wires
from mrmustard.utils.typing import ComplexMatrix
from mrmustard.utils.typing import ComplexMatrix, RealMatrix, Vector
from mrmustard.physics.triples import XY_to_channel_Abc
from mrmustard.physics.bargmann_utils import au2Symplectic, symplectic2Au, XY_of_channel
from ..circuit_components import CircuitComponent

Expand Down Expand Up @@ -446,6 +447,37 @@ def from_quadrature(
BB = QtoB_in >> QQ >> QtoB_out
return Channel.from_ansatz(modes_out, modes_in, BB.ansatz, name)

@classmethod
def from_XY(
cls,
modes_out: Sequence[int],
modes_in: Sequence[int],
X: RealMatrix,
Y: RealMatrix,
d: Vector | None = None,
) -> Channel:
r"""
Initialize a Channel from its XY representation.
Args:
modes: The modes the channel is defined on.
X: The X matrix of the channel.
Y: The Y matrix of the channel.
d: The d vector of the channel.

.. details::
Each Gaussian channel transforms a state with covarince matrix :math:`\Sigma` and mean :math:`\mu`
into a state with covariance matrix :math:`X \Sigma X^T + Y` and vector of means :math:`X\mu + d`.
This channel has a Bargmann triple that is computed in https://arxiv.org/pdf/2209.06069. We borrow
the formulas from the paper to implement the corresponding channel.
"""

if X.shape != (2 * len(modes_out), 2 * len(modes_in)):
raise ValueError(
f"The dimension of X matrix ({X.shape}) and number of modes ({len(modes_in), len(modes_out)}) don't match."
)

return Channel.from_bargmann(modes_out, modes_in, XY_to_channel_Abc(X, Y, d))

@classmethod
def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Channel:
r"""
Expand Down
63 changes: 62 additions & 1 deletion mrmustard/physics/triples.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import numpy as np

from mrmustard import math, settings
from mrmustard.utils.typing import Matrix, Vector, Scalar, RealMatrix
from mrmustard.utils.typing import Matrix, Vector, Scalar, RealMatrix, ComplexMatrix
from mrmustard.physics.gaussian_integrals import complex_gaussian_integral_2


Expand Down Expand Up @@ -778,3 +778,64 @@ def attenuator_kraus_Abc(eta: float) -> Union[Matrix, Vector, Scalar]:
b = _vacuum_B_vector(3)
c = 1.0 + 0j
return A, b, c


def XY_to_channel_Abc(X: RealMatrix, Y: RealMatrix, d: Vector | None = None) -> ComplexMatrix:
r"""
The method to compute the A matrix of a channel based on its X, Y, and d.
Args:
X: The X matrix of the channel
Y: The Y matrix of the channel
d: The d (displacement) vector of the channel -- if None, we consider it as 0
"""

m = Y.shape[-1] // 2
# considering no displacement if d is None
d = d if d else math.zeros(2 * m)

if X.shape != Y.shape:
raise ValueError(
"The dimension of X and Y matrices are not the same."
f"X.shape = {X.shape}, Y.shape = {Y.shape}"
)

xi = 1 / 2 * math.eye(2 * m, dtype=math.complex128) + 1 / 2 * X @ X.T + Y / settings.HBAR
xi_inv = math.inv(xi)
xi_inv_in_blocks = math.block(
[[math.eye(2 * m) - xi_inv, xi_inv @ X], [X.T @ xi_inv, math.eye(2 * m) - X.T @ xi_inv @ X]]
)
R = (
1
/ math.sqrt(complex(2))
* math.block(
[
[
math.eye(m, dtype=math.complex128),
1j * math.eye(m, dtype=math.complex128),
math.zeros((m, 2 * m), dtype=math.complex128),
],
[
math.zeros((m, 2 * m), dtype=math.complex128),
math.eye(m, dtype=math.complex128),
-1j * math.eye(m, dtype=math.complex128),
],
[
math.eye(m, dtype=math.complex128),
-1j * math.eye(m, dtype=math.complex128),
math.zeros((m, 2 * m), dtype=math.complex128),
],
[
math.zeros((m, 2 * m), dtype=math.complex128),
math.eye(m, dtype=math.complex128),
1j * math.eye(m, dtype=math.complex128),
],
]
)
)

A = math.Xmat(2 * m) @ R @ xi_inv_in_blocks @ math.conj(R).T
temp = math.block([[(xi_inv @ d).reshape(2 * m, 1)], [(-X.T @ xi_inv @ d).reshape((2 * m, 1))]])
b = 1 / math.sqrt(settings.HBAR) * math.conj(R) @ temp
c = math.exp(-0.5 / settings.HBAR * d @ xi_inv @ d) / math.sqrt(math.det(xi))

return A, b, c
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import pytest

from mrmustard import math
from mrmustard.lab_dev.circuit_components_utils import TraceOut
from mrmustard.lab_dev.circuit_components import CircuitComponent
from mrmustard.lab_dev.circuit_components_utils import TraceOut
from mrmustard.lab_dev.states import Coherent
Expand Down
7 changes: 4 additions & 3 deletions tests/test_lab_dev/test_states/test_dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@
# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement

from itertools import product

import numpy as np
import pytest

from mrmustard import math, settings
from mrmustard.lab_dev.circuit_components import CircuitComponent
from mrmustard.lab_dev.circuit_components_utils import TraceOut
from mrmustard.physics.gaussian import vacuum_cov
from mrmustard.lab_dev.states import Coherent, DM, Ket, Number, Vacuum
from mrmustard.lab_dev.states import DM, Coherent, Ket, Number, Vacuum
from mrmustard.lab_dev.transformations import Attenuator, Dgate
from mrmustard.physics.wires import Wires
from mrmustard.physics.gaussian import vacuum_cov
from mrmustard.physics.representations import Representation
from mrmustard.physics.wires import Wires


def coherent_state_quad(q, x, y, phi=0):
Expand Down
19 changes: 6 additions & 13 deletions tests/test_lab_dev/test_states/test_ket.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,21 @@
# pylint: disable=unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement

from itertools import product

import numpy as np
import pytest

from ipywidgets import Box, HBox, VBox, HTML
from ipywidgets import HTML, Box, HBox, VBox
from plotly.graph_objs import FigureWidget

from mrmustard import math, settings
from mrmustard.lab_dev.circuit_components import CircuitComponent
from mrmustard.math.parameters import Constant, Variable
from mrmustard.physics.gaussian import vacuum_cov, vacuum_means, squeezed_vacuum_cov
from mrmustard.physics.triples import coherent_state_Abc
from mrmustard.lab_dev.circuit_components_utils import TraceOut
from mrmustard.lab_dev.states import (
Coherent,
DisplacedSqueezed,
DM,
Ket,
Number,
Vacuum,
)
from mrmustard.lab_dev.states import DM, Coherent, DisplacedSqueezed, Ket, Number, Vacuum
from mrmustard.lab_dev.transformations import Attenuator, Dgate, Sgate
from mrmustard.math.parameters import Constant, Variable
from mrmustard.physics.gaussian import squeezed_vacuum_cov, vacuum_cov, vacuum_means
from mrmustard.physics.representations import Representation
from mrmustard.physics.triples import coherent_state_Abc
from mrmustard.physics.wires import Wires
from mrmustard.widgets import state as state_widget

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,11 @@ def test_XY(self):

X, Y = Attenuator([0], 0.2).XY
assert np.allclose(X, np.sqrt(0.2) * np.eye(2)) and np.allclose(Y, 0.4 * np.eye(2))

@pytest.mark.parametrize("nmodes", [1, 2, 3])
def test_from_XY(self, nmodes):
X = np.random.random((2 * nmodes, 2 * nmodes))
Y = np.random.random((2 * nmodes, 2 * nmodes))
x, y = Channel.from_XY(range(nmodes), range(nmodes), X, Y).XY
assert math.allclose(x, X)
assert math.allclose(y, Y)
4 changes: 2 additions & 2 deletions tests/test_physics/test_ansatz/test_array_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from unittest.mock import patch

import numpy as np
from ipywidgets import HBox, VBox, HTML, Tab
from plotly.graph_objs import FigureWidget
import pytest
from ipywidgets import HTML, HBox, Tab, VBox
from plotly.graph_objs import FigureWidget

from mrmustard import math
from mrmustard.physics.ansatz.array_ansatz import ArrayAnsatz
Expand Down
8 changes: 4 additions & 4 deletions tests/test_physics/test_ansatz/test_polyexp_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@
from unittest.mock import patch

import numpy as np
from ipywidgets import Box, VBox, HTML, IntText, Stack, IntSlider
from plotly.graph_objs import FigureWidget
import pytest
from ipywidgets import HTML, Box, IntSlider, IntText, Stack, VBox
from plotly.graph_objs import FigureWidget

from mrmustard import math
from mrmustard.physics.ansatz.array_ansatz import ArrayAnsatz
from mrmustard.physics.ansatz.polyexp_ansatz import PolyExpAnsatz
from mrmustard.physics.gaussian_integrals import (
complex_gaussian_integral_1,
complex_gaussian_integral_2,
)
from mrmustard.physics.ansatz.polyexp_ansatz import PolyExpAnsatz
from mrmustard.physics.ansatz.array_ansatz import ArrayAnsatz

from ...random import Abc_triple

Expand Down
7 changes: 3 additions & 4 deletions tests/test_physics/test_representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
import pytest

from mrmustard import math

from mrmustard.physics.representations import Representation, RepEnum
from mrmustard.physics.wires import Wires
from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz
from mrmustard.physics.triples import displacement_gate_Abc, bargmann_to_quadrature_Abc
from mrmustard.physics.representations import RepEnum, Representation
from mrmustard.physics.triples import bargmann_to_quadrature_Abc, displacement_gate_Abc
from mrmustard.physics.wires import Wires

from ..random import Abc_triple

Expand Down
23 changes: 22 additions & 1 deletion tests/test_physics/test_triples.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import pytest

from mrmustard import math
from mrmustard import math, settings
from mrmustard.physics import triples
from mrmustard.physics.ansatz import PolyExpAnsatz

Expand Down Expand Up @@ -354,3 +354,24 @@ def test_gaussian_random_noise_Abc(self):
assert math.allclose(A, A_by_hand)
assert math.allclose(b, b_by_hand)
assert math.allclose(c, c_by_hand)

def test_XY_to_channel_Abc(self):

# Creating an attenuator object and testing its Abc triple
eta = np.random.random()
X = np.sqrt(eta) * np.eye(2)
Y = settings.HBAR / 2 * (1 - eta) * np.eye(2)

A, b, c = triples.XY_to_channel_Abc(X, Y)

A_by_hand = np.block(
[
[0, np.sqrt(eta), 0, 0],
[np.sqrt(eta), 0, 0, 1 - eta],
[0, 0, 0, np.sqrt(eta)],
[0, 1 - eta, np.sqrt(eta), 0],
]
)
assert np.allclose(A, A_by_hand)
assert np.allclose(b, np.zeros((4, 1)))
assert np.isclose(c, 1.0)
Loading