From 6d70fecd396a17dcc348b4c2cc9dc7b2775ed97b Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Sun, 28 Jan 2024 14:33:27 -0800 Subject: [PATCH 01/24] Create wires.py Part of MM 1.0 refactor --- mrmustard/lab/utils/wires.py | 321 +++++++++++++++++++++++++++++++++++ 1 file changed, 321 insertions(+) create mode 100644 mrmustard/lab/utils/wires.py diff --git a/mrmustard/lab/utils/wires.py b/mrmustard/lab/utils/wires.py new file mode 100644 index 000000000..7fa20bb3e --- /dev/null +++ b/mrmustard/lab/utils/wires.py @@ -0,0 +1,321 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""`Wires` class for handling the connectivity of an object in a circuit.""" + +from __future__ import annotations + +from typing import Iterable, Optional +import numpy as np +from mrmustard import settings + +# pylint: disable=protected-access + + +class Wires: + r"""`Wires` class for handling the connectivity of an object in a circuit. + In MrMustard, ``CircuitComponent``s have a ``Wires`` object as attribute + to handle the wires of the component and to connect components together. + + Wires are arranged into four groups, and each of the groups can + span multiple modes: + _______________ + input bra modes --->| circuit |---> output bra modes + input ket modes --->| component |---> output ket modes + --------------- + Each of the four groups can be empty. In particular, the wires of a state have + no inputs and the wires of a measurement have no outputs. Similarly, + a standard unitary transformation has no bra wires. + + We refer to these four groups in a "standard order": + 0. output_bra for all modes + 1. input_bra for all modes + 2. output_ket for all modes + 3. input_ket for all modes + + A ``Wires`` object can return subsets (views) of itself. Available subsets are: + - input/output (wires on input/output side) + - bra/ket (wires on bra/ket side) + - modes (wires on the given modes) + + For example, ``wires.input`` returns a ``Wires`` object with only the input wires + (on bra and ket sides and on all the modes). + Note that these can be combined together: ``wires.input.bra[(1,2)] returns a + ``Wires`` object with only the input bra wires on modes 1 and 2. + Note these are views of the original ``Wires`` object, i.e. we can set the ``ids`` on the + views and it will be set on the original, e.g. this is a valid way to connect two sets + of wires by setting their ids to be equal: ``wires1.output.ids = wires2.input.ids``. + + A very useful feature of the ``Wires`` class is the support for the right shift + operator ``>>``. This allows us to connect two ``Wires`` objects together as + ``wires1 >> wires2``. This will return the ``Wires`` object of the two + wires objects connected together as if they were in a circuit: + ____________ ____________ + in bra --->| wires1 |---> out bra ---> in bra --->| wires2 |---> out bra + in ket --->| |---> out ket ---> in ket --->| |---> out ket + ------------ ------------ + The returned ``Wires`` object will contain the surviving wires of the two + ``Wires`` objects and it will raise an error if there are overlaps between + the surviving wires. This is especially useful for handling the ordering of the + wires when connecting components together: we are always guaranteed that a + ``Wires`` object will provide the wires in the standard order. + + Args: + modes_out_bra (Iterable[int]): The output modes on the bra side. + modes_in_bra (Iterable[int]): The input modes on the bra side. + modes_out_ket (Iterable[int]): The output modes on the ket side. + modes_in_ket (Iterable[int]): The input modes on the ket side. + + Note that the order of the modes passed to initialize the object doesn't matter, + as they get sorted at init time. + """ + + def __init__( + self, + modes_out_bra: Optional[Iterable[int]] = None, + modes_in_bra: Optional[Iterable[int]] = None, + modes_out_ket: Optional[Iterable[int]] = None, + modes_in_ket: Optional[Iterable[int]] = None, + ) -> None: + modes_out_bra = modes_out_bra or [] + modes_in_bra = modes_in_bra or [] + modes_out_ket = modes_out_ket or [] + modes_in_ket = modes_in_ket or [] + + self._modes = sorted( + set(modes_out_bra) | set(modes_in_bra) | set(modes_out_ket) | set(modes_in_ket) + ) + randint = settings.rng.integers # MM random number generator + outbra = {m: randint(1, 2**62) if m in modes_out_bra else 0 for m in self._modes} + inbra = {m: randint(1, 2**62) if m in modes_in_bra else 0 for m in self._modes} + outket = {m: randint(1, 2**62) if m in modes_out_ket else 0 for m in self._modes} + inket = {m: randint(1, 2**62) if m in modes_in_ket else 0 for m in self._modes} + self._id_array = np.array([[outbra[m], inbra[m], outket[m], inket[m]] for m in self._modes]) + self._mask = np.ones_like(self._id_array) # multiplicative mask + + def _args(self): + r"""Returns the same args one needs to initialize this object.""" + ob_modes = np.array(self._modes)[self._id_array[:, 0] > 0].tolist() + ib_modes = np.array(self._modes)[self._id_array[:, 1] > 0].tolist() + ok_modes = np.array(self._modes)[self._id_array[:, 2] > 0].tolist() + ik_modes = np.array(self._modes)[self._id_array[:, 3] > 0].tolist() + return tuple(ob_modes), tuple(ib_modes), tuple(ok_modes), tuple(ik_modes) + + @classmethod + def _from_data(cls, id_array: np.ndarray, modes: list[int], mask=None): + r"""Private class method to initialize Wires object from the given data.""" + w = cls() + w._id_array = id_array + w._modes = modes + w._mask = mask if mask is not None else np.ones_like(w._id_array) + return w + + def _view(self, masked_rows: tuple[int, ...] = (), masked_cols: tuple[int, ...] = ()) -> Wires: + r"""A masked view of this Wires object.""" + w = self._from_data(self._id_array, self._modes, self._mask.copy()) + w._mask[masked_rows, :] = -1 + w._mask[:, masked_cols] = -1 + return w + + def _mode(self, mode: int) -> np.ndarray: + "A slice of the id_array matrix at the given mode." + return np.maximum(0, self.id_array[[self._modes.index(mode)]])[0] + + @property + def id_array(self) -> np.ndarray: + "The id_array of the available wires in the standard order (bra/ket x out/in x mode)." + return self._id_array * self._mask + + @property + def ids(self) -> list[int]: + "The list of ids of the available wires in the standard order." + flat = self.id_array.T.ravel() + return flat[flat > 0].tolist() + + @ids.setter + def ids(self, ids: list[int]): + "Sets the ids of the available wires." + if len(ids) != len(self.ids): + raise ValueError(f"wrong number of ids (expected {len(self.ids)}, got {len(ids)})") + self._id_array.flat[self.id_array.flatten() > 0] = ids + + @property + def modes(self) -> list[int]: + "The set of modes spanned by the populated wires." + return [m for m in self._modes if any(self.id_array[self._modes.index(m)] > 0)] + + @property + def indices(self) -> list[int]: + r"""Returns the array of indices of this subset in the standard order. + (bra/ket x out/in x mode). Use this to get the indices for bargmann contractions. + """ + flat = self.id_array.T.ravel() + flat = flat[flat != 0] + return np.where(flat > 0)[0].tolist() + + @property + def input(self) -> Wires: + "A view of self without output wires" + return self._view(masked_cols=(0, 2)) + + @property + def output(self) -> Wires: + "A view of self without input wires" + return self._view(masked_cols=(1, 3)) + + @property + def ket(self) -> Wires: + "A view of self without bra wires" + return self._view(masked_cols=(0, 1)) + + @property + def bra(self) -> Wires: + "A view of self without ket wires" + return self._view(masked_cols=(2, 3)) + + @property + def adjoint(self) -> Wires: + r""" + The adjoint (ket <-> bra) of this wires object. + """ + return self._from_data(self._id_array[:, [2, 3, 0, 1]], self._modes, self._mask) + + @property + def dual(self) -> Wires: + r""" + The dual (in <-> out) of this wires object. + """ + return self._from_data(self._id_array[:, [1, 0, 3, 2]], self._modes, self._mask) + + def copy(self) -> Wires: + r"""A copy of this Wires object with new ids.""" + w = Wires(*self._args()) + w._mask = self._mask.copy() + return w + + def __bool__(self) -> bool: + return True if len(self.ids) > 0 else False + + def __getitem__(self, modes: Iterable[int] | int) -> Wires: + "A view of this Wires object with wires only on the given modes." + modes = [modes] if isinstance(modes, int) else modes + idxs = tuple(list(self._modes).index(m) for m in set(self._modes).difference(modes)) + return self._view(masked_rows=idxs) + + def __lshift__(self, other: Wires) -> Wires: + return (other.dual >> self.dual).dual # how cool is this + + def __rshift__(self, other: Wires) -> Wires: + r"""Returns a new Wires object with the wires of self and other combined as two + components in a circuit where the output of self connects to the input of other: + ``self >> other`` + All surviving wires are arranged in the standard order. + A ValueError is raised if there are any surviving wires that overlap, which is the only + possible way two objects aren't compatible in a circuit.""" + all_modes = sorted(set(self.modes) | set(other.modes)) + new_id_array = np.zeros((len(all_modes), 4), dtype=np.int64) + for i, m in enumerate(all_modes): + if m in self.modes and m in other.modes: + # m-th row of self and other (and self output bra = sob, etc...) + sob, sib, sok, sik = self._mode(m) + oob, oib, ook, oik = other._mode(m) + errors = { + "output bra": sob and oob and not oib, + "output ket": sok and ook and not oik, + "input bra": oib and sib and not sob, + "input ket": oik and sik and not sok, + } + if any(errors.values()): + position = [k for k, v in errors.items() if v][0] + raise ValueError(f"{position} wire overlap at mode {m}") + if bool(sob) == bool(oib): # if the inner wires are both there or both not there + new_id_array[i] += np.array([oob, sib, 0, 0]) + elif not sib and not sob: # no wires on self + new_id_array[i] += np.array([oob, oib, 0, 0]) + else: # no wires on other + new_id_array[i] += np.array([sob, sib, 0, 0]) + if bool(sok) == bool(oik): # same as above but on the ket side + new_id_array[i] += np.array([0, 0, ook, sik]) + elif not sik and not sok: + new_id_array[i] += np.array([0, 0, ook, oik]) + else: + new_id_array[i] += np.array([0, 0, sok, sik]) + elif m in self.modes and not m in other.modes: + new_id_array[i] += self._mode(m) + elif m in other.modes and not m in self.modes: + new_id_array[i] += other._mode(m) + return self._from_data(np.abs(new_id_array), all_modes) + + def __repr__(self) -> str: + ob_modes, ib_modes, ok_modes, ik_modes = self._args() + return f"Wires({ob_modes}, {ib_modes}, {ok_modes}, {ik_modes})" + + def _repr_html_(self): # pragma: no cover + "A matrix plot of the id_array." + row_labels = map(str, self._modes) + col_labels = ["bra-out", "bra-in", "ket-out", "ket-in"] + array = np.abs(self.id_array) / (self.id_array + 1e-15) + idxs = (i for i in self.indices) + box_size = "60px" # Set the size of the squares + html = '' + # colors + active = "#5b9bd5" + inactive = "#d6e8f7" + + # Add column headers + html += "" + for label in [""] + col_labels: # Add an empty string for the top-left cell + html += f'' + html += "" + + # Initialize rows with row labels + rows_html = [ + f'' + for label in row_labels + ] + + # Add table cells (column by column) + for label, col in zip(col_labels, array.T): + for row_idx, value in enumerate(col): + color = ( + "white" + if np.isclose(value, 0) + else (active if np.isclose(value, 1) else inactive) + ) + cell_html = f'' + ) + else: + cell_html += '">' + rows_html[row_idx] += cell_html + + # Close the rows and add them to the HTML table + for row_html in rows_html: + row_html += "" + html += row_html + + html += "
{label}
{label}{str(next(idxs))}
" + + try: + from IPython.core.display import display, HTML # pylint: disable=import-outside-toplevel + + display(HTML(html)) + except ImportError as e: + raise ImportError( + "To display the wires in a jupyter notebook you need to `pip install IPython` first." + ) from e From 544cd740b985516b2993ed7e7f1d4b83f902e9b5 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Sun, 28 Jan 2024 14:36:18 -0800 Subject: [PATCH 02/24] adds wires tests --- tests/test_lab/test_wires.py | 96 ++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 tests/test_lab/test_wires.py diff --git a/tests/test_lab/test_wires.py b/tests/test_lab/test_wires.py new file mode 100644 index 000000000..790b2e2fc --- /dev/null +++ b/tests/test_lab/test_wires.py @@ -0,0 +1,96 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Wires class.""" + +# pylint: disable=missing-function-docstring +# pylint: disable=protected-access + +import pytest +from mrmustard.lab_dev.wires import Wires + + +def test_wires_view_has_same_ids(): + w = Wires([0], [0], [0], [0]) + assert set(w.ids) == set(w._view().ids) + + +def test_view_can_edit_original(): + w = Wires([0], [0], [0], [0]) + w._view().ids = [9, 99, 999, 9999] + assert w.ids == [9, 99, 999, 9999] + + +def test_wire_subsets(): + w = Wires([0], [1], [2], [3]) + assert w.output.bra.modes == [0] + assert w.input.bra.modes == [1] + assert w.output.ket.modes == [2] + assert w.input.ket.modes == [3] + + +def test_wire_mode_subsets(): + w = Wires([10], [11], [12], [13]) + assert w[10].ids == w.output.bra.ids + assert w[11].ids == w.input.bra.ids + assert w[12].ids == w.output.ket.ids + assert w[13].ids == w.input.ket.ids + + +def test_indices(): + w = Wires([0, 1, 2], [3, 4, 5], [6, 7], [8]) + assert w.output.indices == [0, 1, 2, 6, 7] + assert w.bra.indices == [0, 1, 2, 3, 4, 5] + assert w.input.indices == [3, 4, 5, 8] + assert w.ket.indices == [6, 7, 8] + + +def test_setting_ids(): + w = Wires([0], [0], [0], [0]) + w.ids = [9, 99, 999, 9999] + assert w.ids == [9, 99, 999, 9999] + + +def test_add_wires(): + w1 = Wires([0], [1], [2], [3]) + w2 = Wires([1], [2], [3], [4]) + w12 = Wires([0, 1], [1, 2], [2, 3], [3, 4]) + assert (w1 + w2).modes == w12.modes + + +def test_cant_add_overlapping_wires(): + w1 = Wires([0], [1], [2], [3]) + w2 = Wires([0], [2], [3], [4]) + with pytest.raises(Exception): + w1 + w2 + + +def test_args(): + w = Wires([0], [1], [2], [3]) + assert w._args() == ((0,), (1,), (2,), (3,)) + + +def test_right_shift_general_contraction(): + # contracts 1,1 on bra side + # contracts 3,3 and 13,13 on ket side (note order doesn't matter) + u = Wires([1, 5], [2, 6, 15], [3, 7, 13], [4, 8]) + v = Wires([0, 9, 14], [1, 10], [2, 11], [13, 3, 12]) + assert (u >> v)._args() == ((0, 5, 9, 14), (2, 6, 10, 15), (2, 7, 11), (4, 8, 12)) + + +def test_error_if_cant_contract(): + u = Wires([], [], [0], []) # only output wire + v = Wires([], [], [0], []) # only output wire + with pytest.raises(ValueError): + u >> v From 1a0eee89eb44f0a000016c1bc1c486b0e4256186 Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Sun, 28 Jan 2024 14:40:08 -0800 Subject: [PATCH 03/24] blacked --- mrmustard/lab/abstract/measurement.py | 6 ++---- mrmustard/lab/gates.py | 4 +--- mrmustard/lab/utils/wires.py | 9 ++++++--- mrmustard/training/callbacks.py | 6 +++--- mrmustard/utils/typing.py | 3 +-- tests/test_math/test_backend_manager.py | 1 + tests/test_math/test_compactFock.py | 1 + 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mrmustard/lab/abstract/measurement.py b/mrmustard/lab/abstract/measurement.py index b0f017281..5f3d87bfd 100644 --- a/mrmustard/lab/abstract/measurement.py +++ b/mrmustard/lab/abstract/measurement.py @@ -87,12 +87,10 @@ def outcome(self): ... @abstractmethod - def _measure_fock(self, other: State) -> Union[State, float]: - ... + def _measure_fock(self, other: State) -> Union[State, float]: ... @abstractmethod - def _measure_gaussian(self, other: State) -> Union[State, float]: - ... + def _measure_gaussian(self, other: State) -> Union[State, float]: ... def primal(self, other: State) -> Union[State, float]: """performs the measurement procedure according to the representation of the incoming state""" diff --git a/mrmustard/lab/gates.py b/mrmustard/lab/gates.py index aadba4bde..374cf39df 100644 --- a/mrmustard/lab/gates.py +++ b/mrmustard/lab/gates.py @@ -1041,9 +1041,7 @@ def primal(self, state): coeff = math.cast( math.exp( - -0.5 - * self.phase_stdev.value**2 - * math.arange(-dm.shape[-2] + 1, dm.shape[-1]) ** 2 + -0.5 * self.phase_stdev.value**2 * math.arange(-dm.shape[-2] + 1, dm.shape[-1]) ** 2 ), dm.dtype, ) diff --git a/mrmustard/lab/utils/wires.py b/mrmustard/lab/utils/wires.py index 7fa20bb3e..48b49b12b 100644 --- a/mrmustard/lab/utils/wires.py +++ b/mrmustard/lab/utils/wires.py @@ -241,9 +241,9 @@ def __rshift__(self, other: Wires) -> Wires: raise ValueError(f"{position} wire overlap at mode {m}") if bool(sob) == bool(oib): # if the inner wires are both there or both not there new_id_array[i] += np.array([oob, sib, 0, 0]) - elif not sib and not sob: # no wires on self + elif not sib and not sob: # no wires on self new_id_array[i] += np.array([oob, oib, 0, 0]) - else: # no wires on other + else: # no wires on other new_id_array[i] += np.array([sob, sib, 0, 0]) if bool(sok) == bool(oik): # same as above but on the ket side new_id_array[i] += np.array([0, 0, ook, sik]) @@ -312,7 +312,10 @@ def _repr_html_(self): # pragma: no cover html += "" try: - from IPython.core.display import display, HTML # pylint: disable=import-outside-toplevel + from IPython.core.display import ( + display, + HTML, + ) # pylint: disable=import-outside-toplevel display(HTML(html)) except ImportError as e: diff --git a/mrmustard/training/callbacks.py b/mrmustard/training/callbacks.py index a343fd94d..a5e8a7874 100644 --- a/mrmustard/training/callbacks.py +++ b/mrmustard/training/callbacks.py @@ -254,9 +254,9 @@ def call( orig_cost = np.array(optimizer.callback_history["orig_cost"][-1]).item() obj_scalars[f"{obj_tag}/orig_cost"] = orig_cost if self.cost_converter is not None: - obj_scalars[ - f"{obj_tag}/{self.cost_converter.__name__}(orig_cost)" - ] = self.cost_converter(orig_cost) + obj_scalars[f"{obj_tag}/{self.cost_converter.__name__}(orig_cost)"] = ( + self.cost_converter(orig_cost) + ) for k, v in obj_scalars.items(): tf.summary.scalar(k, data=v, step=self.optimizer_step) diff --git a/mrmustard/utils/typing.py b/mrmustard/utils/typing.py index 36cf64bd9..84a157784 100644 --- a/mrmustard/utils/typing.py +++ b/mrmustard/utils/typing.py @@ -95,5 +95,4 @@ class Batch(Protocol[T_co]): r"""Anything that can iterate over objects of type T_co.""" - def __iter__(self) -> Iterator[T_co]: - ... + def __iter__(self) -> Iterator[T_co]: ... diff --git a/tests/test_math/test_backend_manager.py b/tests/test_math/test_backend_manager.py index 963ca5cc9..64935f3c0 100644 --- a/tests/test_math/test_backend_manager.py +++ b/tests/test_math/test_backend_manager.py @@ -29,6 +29,7 @@ class TestBackendManager: r""" Tests the BackendManager. """ + l1 = [1.0] l2 = [1.0 + 0.0j, -2.0 + 2.0j] l3 = [[1.0, 2.0], [-3.0, 4.0]] diff --git a/tests/test_math/test_compactFock.py b/tests/test_math/test_compactFock.py index dd51e51fe..48761e5ab 100644 --- a/tests/test_math/test_compactFock.py +++ b/tests/test_math/test_compactFock.py @@ -1,6 +1,7 @@ """ Unit tests for mrmustard.math.compactFock.compactFock~ """ + import importlib import numpy as np From cd088c1fb3ddc47d845797621f005d8e43d4ad7c Mon Sep 17 00:00:00 2001 From: Filippo Miatto Date: Sun, 28 Jan 2024 15:00:18 -0800 Subject: [PATCH 04/24] fixed import --- tests/test_lab/test_wires.py | 2 +- tests/test_lab_dev/test_wires.py | 96 ++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 tests/test_lab_dev/test_wires.py diff --git a/tests/test_lab/test_wires.py b/tests/test_lab/test_wires.py index 790b2e2fc..a47d0198d 100644 --- a/tests/test_lab/test_wires.py +++ b/tests/test_lab/test_wires.py @@ -18,7 +18,7 @@ # pylint: disable=protected-access import pytest -from mrmustard.lab_dev.wires import Wires +from mrmustard.lab.wires import Wires def test_wires_view_has_same_ids(): diff --git a/tests/test_lab_dev/test_wires.py b/tests/test_lab_dev/test_wires.py new file mode 100644 index 000000000..790b2e2fc --- /dev/null +++ b/tests/test_lab_dev/test_wires.py @@ -0,0 +1,96 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Wires class.""" + +# pylint: disable=missing-function-docstring +# pylint: disable=protected-access + +import pytest +from mrmustard.lab_dev.wires import Wires + + +def test_wires_view_has_same_ids(): + w = Wires([0], [0], [0], [0]) + assert set(w.ids) == set(w._view().ids) + + +def test_view_can_edit_original(): + w = Wires([0], [0], [0], [0]) + w._view().ids = [9, 99, 999, 9999] + assert w.ids == [9, 99, 999, 9999] + + +def test_wire_subsets(): + w = Wires([0], [1], [2], [3]) + assert w.output.bra.modes == [0] + assert w.input.bra.modes == [1] + assert w.output.ket.modes == [2] + assert w.input.ket.modes == [3] + + +def test_wire_mode_subsets(): + w = Wires([10], [11], [12], [13]) + assert w[10].ids == w.output.bra.ids + assert w[11].ids == w.input.bra.ids + assert w[12].ids == w.output.ket.ids + assert w[13].ids == w.input.ket.ids + + +def test_indices(): + w = Wires([0, 1, 2], [3, 4, 5], [6, 7], [8]) + assert w.output.indices == [0, 1, 2, 6, 7] + assert w.bra.indices == [0, 1, 2, 3, 4, 5] + assert w.input.indices == [3, 4, 5, 8] + assert w.ket.indices == [6, 7, 8] + + +def test_setting_ids(): + w = Wires([0], [0], [0], [0]) + w.ids = [9, 99, 999, 9999] + assert w.ids == [9, 99, 999, 9999] + + +def test_add_wires(): + w1 = Wires([0], [1], [2], [3]) + w2 = Wires([1], [2], [3], [4]) + w12 = Wires([0, 1], [1, 2], [2, 3], [3, 4]) + assert (w1 + w2).modes == w12.modes + + +def test_cant_add_overlapping_wires(): + w1 = Wires([0], [1], [2], [3]) + w2 = Wires([0], [2], [3], [4]) + with pytest.raises(Exception): + w1 + w2 + + +def test_args(): + w = Wires([0], [1], [2], [3]) + assert w._args() == ((0,), (1,), (2,), (3,)) + + +def test_right_shift_general_contraction(): + # contracts 1,1 on bra side + # contracts 3,3 and 13,13 on ket side (note order doesn't matter) + u = Wires([1, 5], [2, 6, 15], [3, 7, 13], [4, 8]) + v = Wires([0, 9, 14], [1, 10], [2, 11], [13, 3, 12]) + assert (u >> v)._args() == ((0, 5, 9, 14), (2, 6, 10, 15), (2, 7, 11), (4, 8, 12)) + + +def test_error_if_cant_contract(): + u = Wires([], [], [0], []) # only output wire + v = Wires([], [], [0], []) # only output wire + with pytest.raises(ValueError): + u >> v From 486fcc6a806f907af155b975b532837da71e0dd4 Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 07:54:28 +0000 Subject: [PATCH 05/24] moved files --- mrmustard/lab/{utils => }/wires.py | 0 tests/test_lab/test_wires.py | 8 +-- tests/test_lab_dev/test_wires.py | 96 ------------------------------ 3 files changed, 4 insertions(+), 100 deletions(-) rename mrmustard/lab/{utils => }/wires.py (100%) delete mode 100644 tests/test_lab_dev/test_wires.py diff --git a/mrmustard/lab/utils/wires.py b/mrmustard/lab/wires.py similarity index 100% rename from mrmustard/lab/utils/wires.py rename to mrmustard/lab/wires.py diff --git a/tests/test_lab/test_wires.py b/tests/test_lab/test_wires.py index a47d0198d..6b123b587 100644 --- a/tests/test_lab/test_wires.py +++ b/tests/test_lab/test_wires.py @@ -62,18 +62,18 @@ def test_setting_ids(): assert w.ids == [9, 99, 999, 9999] -def test_add_wires(): +def test_non_overlapping_wires(): w1 = Wires([0], [1], [2], [3]) w2 = Wires([1], [2], [3], [4]) w12 = Wires([0, 1], [1, 2], [2, 3], [3, 4]) - assert (w1 + w2).modes == w12.modes + assert (w1 >> w2).modes == w12.modes def test_cant_add_overlapping_wires(): w1 = Wires([0], [1], [2], [3]) w2 = Wires([0], [2], [3], [4]) - with pytest.raises(Exception): - w1 + w2 + with pytest.raises(ValueError): + w1 >> w2 def test_args(): diff --git a/tests/test_lab_dev/test_wires.py b/tests/test_lab_dev/test_wires.py deleted file mode 100644 index 790b2e2fc..000000000 --- a/tests/test_lab_dev/test_wires.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for Wires class.""" - -# pylint: disable=missing-function-docstring -# pylint: disable=protected-access - -import pytest -from mrmustard.lab_dev.wires import Wires - - -def test_wires_view_has_same_ids(): - w = Wires([0], [0], [0], [0]) - assert set(w.ids) == set(w._view().ids) - - -def test_view_can_edit_original(): - w = Wires([0], [0], [0], [0]) - w._view().ids = [9, 99, 999, 9999] - assert w.ids == [9, 99, 999, 9999] - - -def test_wire_subsets(): - w = Wires([0], [1], [2], [3]) - assert w.output.bra.modes == [0] - assert w.input.bra.modes == [1] - assert w.output.ket.modes == [2] - assert w.input.ket.modes == [3] - - -def test_wire_mode_subsets(): - w = Wires([10], [11], [12], [13]) - assert w[10].ids == w.output.bra.ids - assert w[11].ids == w.input.bra.ids - assert w[12].ids == w.output.ket.ids - assert w[13].ids == w.input.ket.ids - - -def test_indices(): - w = Wires([0, 1, 2], [3, 4, 5], [6, 7], [8]) - assert w.output.indices == [0, 1, 2, 6, 7] - assert w.bra.indices == [0, 1, 2, 3, 4, 5] - assert w.input.indices == [3, 4, 5, 8] - assert w.ket.indices == [6, 7, 8] - - -def test_setting_ids(): - w = Wires([0], [0], [0], [0]) - w.ids = [9, 99, 999, 9999] - assert w.ids == [9, 99, 999, 9999] - - -def test_add_wires(): - w1 = Wires([0], [1], [2], [3]) - w2 = Wires([1], [2], [3], [4]) - w12 = Wires([0, 1], [1, 2], [2, 3], [3, 4]) - assert (w1 + w2).modes == w12.modes - - -def test_cant_add_overlapping_wires(): - w1 = Wires([0], [1], [2], [3]) - w2 = Wires([0], [2], [3], [4]) - with pytest.raises(Exception): - w1 + w2 - - -def test_args(): - w = Wires([0], [1], [2], [3]) - assert w._args() == ((0,), (1,), (2,), (3,)) - - -def test_right_shift_general_contraction(): - # contracts 1,1 on bra side - # contracts 3,3 and 13,13 on ket side (note order doesn't matter) - u = Wires([1, 5], [2, 6, 15], [3, 7, 13], [4, 8]) - v = Wires([0, 9, 14], [1, 10], [2, 11], [13, 3, 12]) - assert (u >> v)._args() == ((0, 5, 9, 14), (2, 6, 10, 15), (2, 7, 11), (4, 8, 12)) - - -def test_error_if_cant_contract(): - u = Wires([], [], [0], []) # only output wire - v = Wires([], [], [0], []) # only output wire - with pytest.raises(ValueError): - u >> v From a013d1d2a3760e75869e75eee9e1556b5c6530a5 Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 08:16:08 +0000 Subject: [PATCH 06/24] updated changelog --- .github/CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 2458d4c80..99437fa3f 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -15,12 +15,15 @@ * Added an Ansatz abstract class and PolyExpAnsatz concrete implementation. This is used in the Bargmann representation. [(#295)](https://github.com/XanaduAI/MrMustard/pull/295) -* Added `complex_gaussian_integral` and `real_gaussian_integral` methods. +* Added `complex_gaussian_integral` method. [(#295)](https://github.com/XanaduAI/MrMustard/pull/295) * Added `Bargmann` representation (parametrized by Abc). Supports all algebraic operations and CV (exact) inner product. [(#296)](https://github.com/XanaduAI/MrMustard/pull/296) +* Added a new class `Wires` in `mrmustard.lab` to handle the connectivity of objects in a circuit. + [(#330)](https://github.com/XanaduAI/MrMustard/pull/330) + ### Breaking changes * Removed circular dependencies by: * Removing `graphics.py`--moved `ProgressBar` to `training` and `mikkel_plot` to `lab`. From 81d58ad1e4652a849808246398b11f256faa9bce Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 08:20:09 +0000 Subject: [PATCH 07/24] Refactor __bool__ method in Wires class --- mrmustard/lab/wires.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index 48b49b12b..467795b70 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -205,7 +205,7 @@ def copy(self) -> Wires: return w def __bool__(self) -> bool: - return True if len(self.ids) > 0 else False + return len(self.ids) > 0 def __getitem__(self, modes: Iterable[int] | int) -> Wires: "A view of this Wires object with wires only on the given modes." From 362b1004b7c06898a1f8413aa421cbc6ec12d48a Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 08:21:39 +0000 Subject: [PATCH 08/24] Fix pointless statements in test cases --- tests/test_lab/test_wires.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_lab/test_wires.py b/tests/test_lab/test_wires.py index 6b123b587..20ca42849 100644 --- a/tests/test_lab/test_wires.py +++ b/tests/test_lab/test_wires.py @@ -73,7 +73,7 @@ def test_cant_add_overlapping_wires(): w1 = Wires([0], [1], [2], [3]) w2 = Wires([0], [2], [3], [4]) with pytest.raises(ValueError): - w1 >> w2 + w1 >> w2 # pylint: disable=pointless-statement def test_args(): @@ -93,4 +93,4 @@ def test_error_if_cant_contract(): u = Wires([], [], [0], []) # only output wire v = Wires([], [], [0], []) # only output wire with pytest.raises(ValueError): - u >> v + u >> v # pylint: disable=pointless-statement From 77d7429aca2eab241f1faf57467932f5d95f1c0e Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 08:23:20 +0000 Subject: [PATCH 09/24] Fix display issue in Wires class --- mrmustard/lab/wires.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index 467795b70..c9a70e284 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -312,11 +312,7 @@ def _repr_html_(self): # pragma: no cover html += "" try: - from IPython.core.display import ( - display, - HTML, - ) # pylint: disable=import-outside-toplevel - + from IPython.core.display import display, HTML # pylint: disable=import-outside-toplevel display(HTML(html)) except ImportError as e: raise ImportError( From 93c58c89fa27a45680abea161d115a39fd2e3d9c Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 08:41:54 +0000 Subject: [PATCH 10/24] Refactor wire connection logic in Wires class --- mrmustard/lab/wires.py | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index c9a70e284..ec2ab365d 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -215,6 +215,21 @@ def __getitem__(self, modes: Iterable[int] | int) -> Wires: def __lshift__(self, other: Wires) -> Wires: return (other.dual >> self.dual).dual # how cool is this + + @staticmethod + def _outin(si, so, oi, oo): + r"""Returns the output and input wires of the composite object made by connecting + two single-mode ket (or bra) objects: + --|self|-- --|other|-- + At this stage we are guaranteed that the configurations `|self|-- |other|--` and + `--|self| --|other|` (which would be invalid) have already been excluded. + """ + if bool(so) == bool(oi): # if the inner wires are either both there or both not there + return np.array([oo, si]) + elif not si and not so: # no wires on self + return np.array([oo, oi]) + else: # no wires on other + return np.array([so, si]) def __rshift__(self, other: Wires) -> Wires: r"""Returns a new Wires object with the wires of self and other combined as two @@ -227,30 +242,18 @@ def __rshift__(self, other: Wires) -> Wires: new_id_array = np.zeros((len(all_modes), 4), dtype=np.int64) for i, m in enumerate(all_modes): if m in self.modes and m in other.modes: - # m-th row of self and other (and self output bra = sob, etc...) - sob, sib, sok, sik = self._mode(m) - oob, oib, ook, oik = other._mode(m) + sob, sib, sok, sik = self._mode(m) # m-th row of self + oob, oib, ook, oik = other._mode(m) # m-th row of other errors = { - "output bra": sob and oob and not oib, - "output ket": sok and ook and not oik, - "input bra": oib and sib and not sob, - "input ket": oik and sik and not sok, + "output bra": sob and oob and not oib, # |s|- |o|- (bra) + "output ket": sok and ook and not oik, # |s|- |o|- (ket) + "input bra": oib and sib and not sob, # -|s| -|o| (bra) + "input ket": oik and sik and not sok, # -|s| -|o| (ket) } if any(errors.values()): position = [k for k, v in errors.items() if v][0] raise ValueError(f"{position} wire overlap at mode {m}") - if bool(sob) == bool(oib): # if the inner wires are both there or both not there - new_id_array[i] += np.array([oob, sib, 0, 0]) - elif not sib and not sob: # no wires on self - new_id_array[i] += np.array([oob, oib, 0, 0]) - else: # no wires on other - new_id_array[i] += np.array([sob, sib, 0, 0]) - if bool(sok) == bool(oik): # same as above but on the ket side - new_id_array[i] += np.array([0, 0, ook, sik]) - elif not sik and not sok: - new_id_array[i] += np.array([0, 0, ook, oik]) - else: - new_id_array[i] += np.array([0, 0, sok, sik]) + new_id_array[i] += np.hstack([self._outin(sib, sob, oib, oob), self._outin(sik, sok, oik, ook)]) elif m in self.modes and not m in other.modes: new_id_array[i] += self._mode(m) elif m in other.modes and not m in self.modes: From 84f7f04f652ef542efeaea5653f921f5686d0a24 Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 08:44:29 +0000 Subject: [PATCH 11/24] Fix whitespace issue in wires.py --- mrmustard/lab/wires.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index ec2ab365d..ca4270e64 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -215,7 +215,7 @@ def __getitem__(self, modes: Iterable[int] | int) -> Wires: def __lshift__(self, other: Wires) -> Wires: return (other.dual >> self.dual).dual # how cool is this - + @staticmethod def _outin(si, so, oi, oo): r"""Returns the output and input wires of the composite object made by connecting @@ -254,9 +254,9 @@ def __rshift__(self, other: Wires) -> Wires: position = [k for k, v in errors.items() if v][0] raise ValueError(f"{position} wire overlap at mode {m}") new_id_array[i] += np.hstack([self._outin(sib, sob, oib, oob), self._outin(sik, sok, oik, ook)]) - elif m in self.modes and not m in other.modes: + elif m in self.modes and m not in other.modes: new_id_array[i] += self._mode(m) - elif m in other.modes and not m in self.modes: + elif m in other.modes and m not in self.modes: new_id_array[i] += other._mode(m) return self._from_data(np.abs(new_id_array), all_modes) From 06226eeb2a3b5f3ff1f21370f2fdcf0f6d998cd1 Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 08:49:09 +0000 Subject: [PATCH 12/24] Fix data type in Wires class methods --- mrmustard/lab/wires.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index ca4270e64..919134748 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -225,11 +225,11 @@ def _outin(si, so, oi, oo): `--|self| --|other|` (which would be invalid) have already been excluded. """ if bool(so) == bool(oi): # if the inner wires are either both there or both not there - return np.array([oo, si]) + return np.array([oo, si], dtype=np.int64) elif not si and not so: # no wires on self - return np.array([oo, oi]) + return np.array([oo, oi], dtype=np.int64) else: # no wires on other - return np.array([so, si]) + return np.array([so, si], dtype=np.int64) def __rshift__(self, other: Wires) -> Wires: r"""Returns a new Wires object with the wires of self and other combined as two @@ -258,7 +258,7 @@ def __rshift__(self, other: Wires) -> Wires: new_id_array[i] += self._mode(m) elif m in other.modes and m not in self.modes: new_id_array[i] += other._mode(m) - return self._from_data(np.abs(new_id_array), all_modes) + return self._from_data(new_id_array, all_modes) # abs to turn hidden ids (negative) into visible def __repr__(self) -> str: ob_modes, ib_modes, ok_modes, ik_modes = self._args() From 466d676d52acd51f29c5f4064360b6c166f5be0d Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 08:51:34 +0000 Subject: [PATCH 13/24] Fix pylint warnings in test_wires.py --- tests/test_lab/test_wires.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_lab/test_wires.py b/tests/test_lab/test_wires.py index 20ca42849..1f470bc89 100644 --- a/tests/test_lab/test_wires.py +++ b/tests/test_lab/test_wires.py @@ -73,7 +73,7 @@ def test_cant_add_overlapping_wires(): w1 = Wires([0], [1], [2], [3]) w2 = Wires([0], [2], [3], [4]) with pytest.raises(ValueError): - w1 >> w2 # pylint: disable=pointless-statement + w1 >> w2 # pylint: disable=pointless-statement # type: ignore def test_args(): @@ -93,4 +93,4 @@ def test_error_if_cant_contract(): u = Wires([], [], [0], []) # only output wire v = Wires([], [], [0], []) # only output wire with pytest.raises(ValueError): - u >> v # pylint: disable=pointless-statement + u >> v # pylint: disable=pointless-statement # type: ignore From 4907ec0a72df2261851efd7c1f50fe553d0f1233 Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 08:55:04 +0000 Subject: [PATCH 14/24] Refactor __rshift__ method in Wires class --- mrmustard/lab/wires.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index 919134748..8b41e1651 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -219,8 +219,7 @@ def __lshift__(self, other: Wires) -> Wires: @staticmethod def _outin(si, so, oi, oo): r"""Returns the output and input wires of the composite object made by connecting - two single-mode ket (or bra) objects: - --|self|-- --|other|-- + two single-mode ket (or bra) objects like --|self|-- and --|other|-- At this stage we are guaranteed that the configurations `|self|-- |other|--` and `--|self| --|other|` (which would be invalid) have already been excluded. """ From 797ebc5508ae839e5615d0d0be424e21003d5172 Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 08:55:39 +0000 Subject: [PATCH 15/24] simplify errors raised in Wires class --- mrmustard/lab/wires.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index 8b41e1651..0e49b1dce 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -231,33 +231,26 @@ def _outin(si, so, oi, oo): return np.array([so, si], dtype=np.int64) def __rshift__(self, other: Wires) -> Wires: - r"""Returns a new Wires object with the wires of self and other combined as two - components in a circuit where the output of self connects to the input of other: - ``self >> other`` - All surviving wires are arranged in the standard order. - A ValueError is raised if there are any surviving wires that overlap, which is the only - possible way two objects aren't compatible in a circuit.""" all_modes = sorted(set(self.modes) | set(other.modes)) new_id_array = np.zeros((len(all_modes), 4), dtype=np.int64) + for i, m in enumerate(all_modes): if m in self.modes and m in other.modes: sob, sib, sok, sik = self._mode(m) # m-th row of self oob, oib, ook, oik = other._mode(m) # m-th row of other - errors = { - "output bra": sob and oob and not oib, # |s|- |o|- (bra) - "output ket": sok and ook and not oik, # |s|- |o|- (ket) - "input bra": oib and sib and not sob, # -|s| -|o| (bra) - "input ket": oik and sik and not sok, # -|s| -|o| (ket) - } - if any(errors.values()): - position = [k for k, v in errors.items() if v][0] - raise ValueError(f"{position} wire overlap at mode {m}") + + if (sob and oob and not oib) or (sok and ook and not oik): + raise ValueError(f"Output wire overlap at mode {m}") + elif (oib and sib and not sob) or (oik and sik and not sok): + raise ValueError(f"Input wire overlap at mode {m}") + new_id_array[i] += np.hstack([self._outin(sib, sob, oib, oob), self._outin(sik, sok, oik, ook)]) - elif m in self.modes and m not in other.modes: + elif m in self.modes: new_id_array[i] += self._mode(m) - elif m in other.modes and m not in self.modes: + elif m in other.modes: new_id_array[i] += other._mode(m) - return self._from_data(new_id_array, all_modes) # abs to turn hidden ids (negative) into visible + + return self._from_data(new_id_array, all_modes) def __repr__(self) -> str: ob_modes, ib_modes, ok_modes, ik_modes = self._args() From 247a34fd09f29c158aec88068ea77988f89e3957 Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 09:09:14 +0000 Subject: [PATCH 16/24] Refactor wire mode merging logic --- mrmustard/lab/wires.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index 0e49b1dce..2481c28f7 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -234,21 +234,20 @@ def __rshift__(self, other: Wires) -> Wires: all_modes = sorted(set(self.modes) | set(other.modes)) new_id_array = np.zeros((len(all_modes), 4), dtype=np.int64) - for i, m in enumerate(all_modes): - if m in self.modes and m in other.modes: - sob, sib, sok, sik = self._mode(m) # m-th row of self - oob, oib, ook, oik = other._mode(m) # m-th row of other + for m in set(self.modes) & set(other.modes): + sob, sib, sok, sik = self._mode(m) # m-th row of self + oob, oib, ook, oik = other._mode(m) # m-th row of other + + if (sob and oob and not oib) or (sok and ook and not oik): + raise ValueError(f"Output wire overlap at mode {m}") + if (oib and sib and not sob) or (oik and sik and not sok): + raise ValueError(f"Input wire overlap at mode {m}") - if (sob and oob and not oib) or (sok and ook and not oik): - raise ValueError(f"Output wire overlap at mode {m}") - elif (oib and sib and not sob) or (oik and sik and not sok): - raise ValueError(f"Input wire overlap at mode {m}") - - new_id_array[i] += np.hstack([self._outin(sib, sob, oib, oob), self._outin(sik, sok, oik, ook)]) - elif m in self.modes: - new_id_array[i] += self._mode(m) - elif m in other.modes: - new_id_array[i] += other._mode(m) + new_id_array[all_modes.index(m)] = np.hstack([self._outin(sib, sob, oib, oob), self._outin(sik, sok, oik, ook)]) + for m in set(self.modes) - set(other.modes): + new_id_array[all_modes.index(m)] = self._mode(m) + for m in set(other.modes) - set(self.modes): + new_id_array[all_modes.index(m)] = other._mode(m) return self._from_data(new_id_array, all_modes) From ebb95f5f5cd8612cec700915b3b59407bb80074f Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 09:09:54 +0000 Subject: [PATCH 17/24] Fix typo in comment --- mrmustard/lab/wires.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index 2481c28f7..b5a55e4ed 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -219,7 +219,7 @@ def __lshift__(self, other: Wires) -> Wires: @staticmethod def _outin(si, so, oi, oo): r"""Returns the output and input wires of the composite object made by connecting - two single-mode ket (or bra) objects like --|self|-- and --|other|-- + two single-mode (ket or bra) objects like --|self|-- and --|other|-- At this stage we are guaranteed that the configurations `|self|-- |other|--` and `--|self| --|other|` (which would be invalid) have already been excluded. """ From d7c219f60ba411cc388e93c95cec03bf324ace42 Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 09:43:12 +0000 Subject: [PATCH 18/24] Fix formatting issues in Wires class --- mrmustard/lab/wires.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index b5a55e4ed..e3fa3b109 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -220,7 +220,7 @@ def __lshift__(self, other: Wires) -> Wires: def _outin(si, so, oi, oo): r"""Returns the output and input wires of the composite object made by connecting two single-mode (ket or bra) objects like --|self|-- and --|other|-- - At this stage we are guaranteed that the configurations `|self|-- |other|--` and + At this stage we are guaranteed that the configurations `|self|-- |other|--` and `--|self| --|other|` (which would be invalid) have already been excluded. """ if bool(so) == bool(oi): # if the inner wires are either both there or both not there @@ -233,7 +233,7 @@ def _outin(si, so, oi, oo): def __rshift__(self, other: Wires) -> Wires: all_modes = sorted(set(self.modes) | set(other.modes)) new_id_array = np.zeros((len(all_modes), 4), dtype=np.int64) - + for m in set(self.modes) & set(other.modes): sob, sib, sok, sik = self._mode(m) # m-th row of self oob, oib, ook, oik = other._mode(m) # m-th row of other @@ -242,13 +242,15 @@ def __rshift__(self, other: Wires) -> Wires: raise ValueError(f"Output wire overlap at mode {m}") if (oib and sib and not sob) or (oik and sik and not sok): raise ValueError(f"Input wire overlap at mode {m}") - - new_id_array[all_modes.index(m)] = np.hstack([self._outin(sib, sob, oib, oob), self._outin(sik, sok, oik, ook)]) + + new_id_array[all_modes.index(m)] = np.hstack( + [self._outin(sib, sob, oib, oob), self._outin(sik, sok, oik, ook)] + ) for m in set(self.modes) - set(other.modes): new_id_array[all_modes.index(m)] = self._mode(m) for m in set(other.modes) - set(self.modes): new_id_array[all_modes.index(m)] = other._mode(m) - + return self._from_data(new_id_array, all_modes) def __repr__(self) -> str: @@ -306,7 +308,11 @@ def _repr_html_(self): # pragma: no cover html += "" try: - from IPython.core.display import display, HTML # pylint: disable=import-outside-toplevel + from IPython.core.display import ( + display, + HTML, + ) # pylint: disable=import-outside-toplevel + display(HTML(html)) except ImportError as e: raise ImportError( From a5edc3253feaf62d6f7f616c8df69f9cec8aa34e Mon Sep 17 00:00:00 2001 From: ziofil Date: Mon, 29 Jan 2024 09:47:18 +0000 Subject: [PATCH 19/24] Fix import error in Wires class --- mrmustard/lab/wires.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index e3fa3b109..6f5d2a8d2 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -17,11 +17,13 @@ from __future__ import annotations from typing import Iterable, Optional + import numpy as np + from mrmustard import settings # pylint: disable=protected-access - +# pylint: disable=import-outside-toplevel class Wires: r"""`Wires` class for handling the connectivity of an object in a circuit. @@ -309,9 +311,9 @@ def _repr_html_(self): # pragma: no cover try: from IPython.core.display import ( - display, HTML, - ) # pylint: disable=import-outside-toplevel + display, + ) display(HTML(html)) except ImportError as e: From 0e9e04c4bf75652add6d03d231da58a52cc7cf31 Mon Sep 17 00:00:00 2001 From: ziofil Date: Tue, 30 Jan 2024 18:04:59 +0000 Subject: [PATCH 20/24] blacked --- mrmustard/lab/wires.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index 6f5d2a8d2..8d27e494c 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -25,6 +25,7 @@ # pylint: disable=protected-access # pylint: disable=import-outside-toplevel + class Wires: r"""`Wires` class for handling the connectivity of an object in a circuit. In MrMustard, ``CircuitComponent``s have a ``Wires`` object as attribute From eddf339b68d619c102f3dd0727fbdf8c3a0d16c5 Mon Sep 17 00:00:00 2001 From: ziofil Date: Tue, 30 Jan 2024 18:07:09 +0000 Subject: [PATCH 21/24] simplify attempt (for codecov) --- mrmustard/lab/wires.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab/wires.py index 8d27e494c..9f4107622 100644 --- a/mrmustard/lab/wires.py +++ b/mrmustard/lab/wires.py @@ -241,9 +241,13 @@ def __rshift__(self, other: Wires) -> Wires: sob, sib, sok, sik = self._mode(m) # m-th row of self oob, oib, ook, oik = other._mode(m) # m-th row of other - if (sob and oob and not oib) or (sok and ook and not oik): + out_bra_issue = sob and oob and not oib + out_ket_issue = sok and ook and not oik + if out_bra_issue or out_ket_issue: raise ValueError(f"Output wire overlap at mode {m}") - if (oib and sib and not sob) or (oik and sik and not sok): + in_bra_issue = oib and sib and not sob + in_ket_issue = oik and sik and not sok + if in_bra_issue or in_ket_issue: raise ValueError(f"Input wire overlap at mode {m}") new_id_array[all_modes.index(m)] = np.hstack( From a75934322b801e3d857f17d7238b916ea682b90d Mon Sep 17 00:00:00 2001 From: ziofil Date: Tue, 30 Jan 2024 18:27:05 +0000 Subject: [PATCH 22/24] trick codefactor --- mrmustard/lab/abstract/state.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mrmustard/lab/abstract/state.py b/mrmustard/lab/abstract/state.py index f914fdbb6..25b2ab177 100644 --- a/mrmustard/lab/abstract/state.py +++ b/mrmustard/lab/abstract/state.py @@ -591,7 +591,6 @@ def get_modes(self, item) -> State: fock_partitioned = fock.trace(self.dm(self.cutoffs), keep=item_idx) return State(dm=fock_partitioned, modes=item) - # TODO: refactor def __eq__(self, other) -> bool: # pylint: disable=too-many-return-statements r"""Returns whether the states are equal.""" if self.num_modes != other.num_modes: From a2af8c3762abed604914973424a8421cc254b883 Mon Sep 17 00:00:00 2001 From: ziofil Date: Tue, 30 Jan 2024 18:58:51 +0000 Subject: [PATCH 23/24] moved wires --- mrmustard/{lab => lab_dev}/wires.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename mrmustard/{lab => lab_dev}/wires.py (100%) diff --git a/mrmustard/lab/wires.py b/mrmustard/lab_dev/wires.py similarity index 100% rename from mrmustard/lab/wires.py rename to mrmustard/lab_dev/wires.py From ddec32232226025621d286dbd698df0a2786fad4 Mon Sep 17 00:00:00 2001 From: ziofil Date: Tue, 30 Jan 2024 19:03:18 +0000 Subject: [PATCH 24/24] updated import and moved test_wires.py --- tests/{test_lab => test_lab_dev}/test_wires.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/{test_lab => test_lab_dev}/test_wires.py (98%) diff --git a/tests/test_lab/test_wires.py b/tests/test_lab_dev/test_wires.py similarity index 98% rename from tests/test_lab/test_wires.py rename to tests/test_lab_dev/test_wires.py index 1f470bc89..cc4921fc9 100644 --- a/tests/test_lab/test_wires.py +++ b/tests/test_lab_dev/test_wires.py @@ -18,7 +18,7 @@ # pylint: disable=protected-access import pytest -from mrmustard.lab.wires import Wires +from mrmustard.lab_dev.wires import Wires def test_wires_view_has_same_ids():