diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 2458d4c80..99437fa3f 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -15,12 +15,15 @@ * Added an Ansatz abstract class and PolyExpAnsatz concrete implementation. This is used in the Bargmann representation. [(#295)](https://github.com/XanaduAI/MrMustard/pull/295) -* Added `complex_gaussian_integral` and `real_gaussian_integral` methods. +* Added `complex_gaussian_integral` method. [(#295)](https://github.com/XanaduAI/MrMustard/pull/295) * Added `Bargmann` representation (parametrized by Abc). Supports all algebraic operations and CV (exact) inner product. [(#296)](https://github.com/XanaduAI/MrMustard/pull/296) +* Added a new class `Wires` in `mrmustard.lab` to handle the connectivity of objects in a circuit. + [(#330)](https://github.com/XanaduAI/MrMustard/pull/330) + ### Breaking changes * Removed circular dependencies by: * Removing `graphics.py`--moved `ProgressBar` to `training` and `mikkel_plot` to `lab`. diff --git a/mrmustard/lab/abstract/measurement.py b/mrmustard/lab/abstract/measurement.py index b0f017281..5f3d87bfd 100644 --- a/mrmustard/lab/abstract/measurement.py +++ b/mrmustard/lab/abstract/measurement.py @@ -87,12 +87,10 @@ def outcome(self): ... @abstractmethod - def _measure_fock(self, other: State) -> Union[State, float]: - ... + def _measure_fock(self, other: State) -> Union[State, float]: ... @abstractmethod - def _measure_gaussian(self, other: State) -> Union[State, float]: - ... + def _measure_gaussian(self, other: State) -> Union[State, float]: ... def primal(self, other: State) -> Union[State, float]: """performs the measurement procedure according to the representation of the incoming state""" diff --git a/mrmustard/lab/abstract/state.py b/mrmustard/lab/abstract/state.py index f914fdbb6..25b2ab177 100644 --- a/mrmustard/lab/abstract/state.py +++ b/mrmustard/lab/abstract/state.py @@ -591,7 +591,6 @@ def get_modes(self, item) -> State: fock_partitioned = fock.trace(self.dm(self.cutoffs), keep=item_idx) return State(dm=fock_partitioned, modes=item) - # TODO: refactor def __eq__(self, other) -> bool: # pylint: disable=too-many-return-statements r"""Returns whether the states are equal.""" if self.num_modes != other.num_modes: diff --git a/mrmustard/lab/gates.py b/mrmustard/lab/gates.py index aadba4bde..374cf39df 100644 --- a/mrmustard/lab/gates.py +++ b/mrmustard/lab/gates.py @@ -1041,9 +1041,7 @@ def primal(self, state): coeff = math.cast( math.exp( - -0.5 - * self.phase_stdev.value**2 - * math.arange(-dm.shape[-2] + 1, dm.shape[-1]) ** 2 + -0.5 * self.phase_stdev.value**2 * math.arange(-dm.shape[-2] + 1, dm.shape[-1]) ** 2 ), dm.dtype, ) diff --git a/mrmustard/lab_dev/wires.py b/mrmustard/lab_dev/wires.py new file mode 100644 index 000000000..9f4107622 --- /dev/null +++ b/mrmustard/lab_dev/wires.py @@ -0,0 +1,327 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""`Wires` class for handling the connectivity of an object in a circuit.""" + +from __future__ import annotations + +from typing import Iterable, Optional + +import numpy as np + +from mrmustard import settings + +# pylint: disable=protected-access +# pylint: disable=import-outside-toplevel + + +class Wires: + r"""`Wires` class for handling the connectivity of an object in a circuit. + In MrMustard, ``CircuitComponent``s have a ``Wires`` object as attribute + to handle the wires of the component and to connect components together. + + Wires are arranged into four groups, and each of the groups can + span multiple modes: + _______________ + input bra modes --->| circuit |---> output bra modes + input ket modes --->| component |---> output ket modes + --------------- + Each of the four groups can be empty. In particular, the wires of a state have + no inputs and the wires of a measurement have no outputs. Similarly, + a standard unitary transformation has no bra wires. + + We refer to these four groups in a "standard order": + 0. output_bra for all modes + 1. input_bra for all modes + 2. output_ket for all modes + 3. input_ket for all modes + + A ``Wires`` object can return subsets (views) of itself. Available subsets are: + - input/output (wires on input/output side) + - bra/ket (wires on bra/ket side) + - modes (wires on the given modes) + + For example, ``wires.input`` returns a ``Wires`` object with only the input wires + (on bra and ket sides and on all the modes). + Note that these can be combined together: ``wires.input.bra[(1,2)] returns a + ``Wires`` object with only the input bra wires on modes 1 and 2. + Note these are views of the original ``Wires`` object, i.e. we can set the ``ids`` on the + views and it will be set on the original, e.g. this is a valid way to connect two sets + of wires by setting their ids to be equal: ``wires1.output.ids = wires2.input.ids``. + + A very useful feature of the ``Wires`` class is the support for the right shift + operator ``>>``. This allows us to connect two ``Wires`` objects together as + ``wires1 >> wires2``. This will return the ``Wires`` object of the two + wires objects connected together as if they were in a circuit: + ____________ ____________ + in bra --->| wires1 |---> out bra ---> in bra --->| wires2 |---> out bra + in ket --->| |---> out ket ---> in ket --->| |---> out ket + ------------ ------------ + The returned ``Wires`` object will contain the surviving wires of the two + ``Wires`` objects and it will raise an error if there are overlaps between + the surviving wires. This is especially useful for handling the ordering of the + wires when connecting components together: we are always guaranteed that a + ``Wires`` object will provide the wires in the standard order. + + Args: + modes_out_bra (Iterable[int]): The output modes on the bra side. + modes_in_bra (Iterable[int]): The input modes on the bra side. + modes_out_ket (Iterable[int]): The output modes on the ket side. + modes_in_ket (Iterable[int]): The input modes on the ket side. + + Note that the order of the modes passed to initialize the object doesn't matter, + as they get sorted at init time. + """ + + def __init__( + self, + modes_out_bra: Optional[Iterable[int]] = None, + modes_in_bra: Optional[Iterable[int]] = None, + modes_out_ket: Optional[Iterable[int]] = None, + modes_in_ket: Optional[Iterable[int]] = None, + ) -> None: + modes_out_bra = modes_out_bra or [] + modes_in_bra = modes_in_bra or [] + modes_out_ket = modes_out_ket or [] + modes_in_ket = modes_in_ket or [] + + self._modes = sorted( + set(modes_out_bra) | set(modes_in_bra) | set(modes_out_ket) | set(modes_in_ket) + ) + randint = settings.rng.integers # MM random number generator + outbra = {m: randint(1, 2**62) if m in modes_out_bra else 0 for m in self._modes} + inbra = {m: randint(1, 2**62) if m in modes_in_bra else 0 for m in self._modes} + outket = {m: randint(1, 2**62) if m in modes_out_ket else 0 for m in self._modes} + inket = {m: randint(1, 2**62) if m in modes_in_ket else 0 for m in self._modes} + self._id_array = np.array([[outbra[m], inbra[m], outket[m], inket[m]] for m in self._modes]) + self._mask = np.ones_like(self._id_array) # multiplicative mask + + def _args(self): + r"""Returns the same args one needs to initialize this object.""" + ob_modes = np.array(self._modes)[self._id_array[:, 0] > 0].tolist() + ib_modes = np.array(self._modes)[self._id_array[:, 1] > 0].tolist() + ok_modes = np.array(self._modes)[self._id_array[:, 2] > 0].tolist() + ik_modes = np.array(self._modes)[self._id_array[:, 3] > 0].tolist() + return tuple(ob_modes), tuple(ib_modes), tuple(ok_modes), tuple(ik_modes) + + @classmethod + def _from_data(cls, id_array: np.ndarray, modes: list[int], mask=None): + r"""Private class method to initialize Wires object from the given data.""" + w = cls() + w._id_array = id_array + w._modes = modes + w._mask = mask if mask is not None else np.ones_like(w._id_array) + return w + + def _view(self, masked_rows: tuple[int, ...] = (), masked_cols: tuple[int, ...] = ()) -> Wires: + r"""A masked view of this Wires object.""" + w = self._from_data(self._id_array, self._modes, self._mask.copy()) + w._mask[masked_rows, :] = -1 + w._mask[:, masked_cols] = -1 + return w + + def _mode(self, mode: int) -> np.ndarray: + "A slice of the id_array matrix at the given mode." + return np.maximum(0, self.id_array[[self._modes.index(mode)]])[0] + + @property + def id_array(self) -> np.ndarray: + "The id_array of the available wires in the standard order (bra/ket x out/in x mode)." + return self._id_array * self._mask + + @property + def ids(self) -> list[int]: + "The list of ids of the available wires in the standard order." + flat = self.id_array.T.ravel() + return flat[flat > 0].tolist() + + @ids.setter + def ids(self, ids: list[int]): + "Sets the ids of the available wires." + if len(ids) != len(self.ids): + raise ValueError(f"wrong number of ids (expected {len(self.ids)}, got {len(ids)})") + self._id_array.flat[self.id_array.flatten() > 0] = ids + + @property + def modes(self) -> list[int]: + "The set of modes spanned by the populated wires." + return [m for m in self._modes if any(self.id_array[self._modes.index(m)] > 0)] + + @property + def indices(self) -> list[int]: + r"""Returns the array of indices of this subset in the standard order. + (bra/ket x out/in x mode). Use this to get the indices for bargmann contractions. + """ + flat = self.id_array.T.ravel() + flat = flat[flat != 0] + return np.where(flat > 0)[0].tolist() + + @property + def input(self) -> Wires: + "A view of self without output wires" + return self._view(masked_cols=(0, 2)) + + @property + def output(self) -> Wires: + "A view of self without input wires" + return self._view(masked_cols=(1, 3)) + + @property + def ket(self) -> Wires: + "A view of self without bra wires" + return self._view(masked_cols=(0, 1)) + + @property + def bra(self) -> Wires: + "A view of self without ket wires" + return self._view(masked_cols=(2, 3)) + + @property + def adjoint(self) -> Wires: + r""" + The adjoint (ket <-> bra) of this wires object. + """ + return self._from_data(self._id_array[:, [2, 3, 0, 1]], self._modes, self._mask) + + @property + def dual(self) -> Wires: + r""" + The dual (in <-> out) of this wires object. + """ + return self._from_data(self._id_array[:, [1, 0, 3, 2]], self._modes, self._mask) + + def copy(self) -> Wires: + r"""A copy of this Wires object with new ids.""" + w = Wires(*self._args()) + w._mask = self._mask.copy() + return w + + def __bool__(self) -> bool: + return len(self.ids) > 0 + + def __getitem__(self, modes: Iterable[int] | int) -> Wires: + "A view of this Wires object with wires only on the given modes." + modes = [modes] if isinstance(modes, int) else modes + idxs = tuple(list(self._modes).index(m) for m in set(self._modes).difference(modes)) + return self._view(masked_rows=idxs) + + def __lshift__(self, other: Wires) -> Wires: + return (other.dual >> self.dual).dual # how cool is this + + @staticmethod + def _outin(si, so, oi, oo): + r"""Returns the output and input wires of the composite object made by connecting + two single-mode (ket or bra) objects like --|self|-- and --|other|-- + At this stage we are guaranteed that the configurations `|self|-- |other|--` and + `--|self| --|other|` (which would be invalid) have already been excluded. + """ + if bool(so) == bool(oi): # if the inner wires are either both there or both not there + return np.array([oo, si], dtype=np.int64) + elif not si and not so: # no wires on self + return np.array([oo, oi], dtype=np.int64) + else: # no wires on other + return np.array([so, si], dtype=np.int64) + + def __rshift__(self, other: Wires) -> Wires: + all_modes = sorted(set(self.modes) | set(other.modes)) + new_id_array = np.zeros((len(all_modes), 4), dtype=np.int64) + + for m in set(self.modes) & set(other.modes): + sob, sib, sok, sik = self._mode(m) # m-th row of self + oob, oib, ook, oik = other._mode(m) # m-th row of other + + out_bra_issue = sob and oob and not oib + out_ket_issue = sok and ook and not oik + if out_bra_issue or out_ket_issue: + raise ValueError(f"Output wire overlap at mode {m}") + in_bra_issue = oib and sib and not sob + in_ket_issue = oik and sik and not sok + if in_bra_issue or in_ket_issue: + raise ValueError(f"Input wire overlap at mode {m}") + + new_id_array[all_modes.index(m)] = np.hstack( + [self._outin(sib, sob, oib, oob), self._outin(sik, sok, oik, ook)] + ) + for m in set(self.modes) - set(other.modes): + new_id_array[all_modes.index(m)] = self._mode(m) + for m in set(other.modes) - set(self.modes): + new_id_array[all_modes.index(m)] = other._mode(m) + + return self._from_data(new_id_array, all_modes) + + def __repr__(self) -> str: + ob_modes, ib_modes, ok_modes, ik_modes = self._args() + return f"Wires({ob_modes}, {ib_modes}, {ok_modes}, {ik_modes})" + + def _repr_html_(self): # pragma: no cover + "A matrix plot of the id_array." + row_labels = map(str, self._modes) + col_labels = ["bra-out", "bra-in", "ket-out", "ket-in"] + array = np.abs(self.id_array) / (self.id_array + 1e-15) + idxs = (i for i in self.indices) + box_size = "60px" # Set the size of the squares + html = '' + # colors + active = "#5b9bd5" + inactive = "#d6e8f7" + + # Add column headers + html += "" + for label in [""] + col_labels: # Add an empty string for the top-left cell + html += f'' + html += "" + + # Initialize rows with row labels + rows_html = [ + f'' + for label in row_labels + ] + + # Add table cells (column by column) + for label, col in zip(col_labels, array.T): + for row_idx, value in enumerate(col): + color = ( + "white" + if np.isclose(value, 0) + else (active if np.isclose(value, 1) else inactive) + ) + cell_html = f'' + ) + else: + cell_html += '">' + rows_html[row_idx] += cell_html + + # Close the rows and add them to the HTML table + for row_html in rows_html: + row_html += "" + html += row_html + + html += "
{label}
{label}{str(next(idxs))}
" + + try: + from IPython.core.display import ( + HTML, + display, + ) + + display(HTML(html)) + except ImportError as e: + raise ImportError( + "To display the wires in a jupyter notebook you need to `pip install IPython` first." + ) from e diff --git a/mrmustard/training/callbacks.py b/mrmustard/training/callbacks.py index a343fd94d..a5e8a7874 100644 --- a/mrmustard/training/callbacks.py +++ b/mrmustard/training/callbacks.py @@ -254,9 +254,9 @@ def call( orig_cost = np.array(optimizer.callback_history["orig_cost"][-1]).item() obj_scalars[f"{obj_tag}/orig_cost"] = orig_cost if self.cost_converter is not None: - obj_scalars[ - f"{obj_tag}/{self.cost_converter.__name__}(orig_cost)" - ] = self.cost_converter(orig_cost) + obj_scalars[f"{obj_tag}/{self.cost_converter.__name__}(orig_cost)"] = ( + self.cost_converter(orig_cost) + ) for k, v in obj_scalars.items(): tf.summary.scalar(k, data=v, step=self.optimizer_step) diff --git a/mrmustard/utils/typing.py b/mrmustard/utils/typing.py index 36cf64bd9..84a157784 100644 --- a/mrmustard/utils/typing.py +++ b/mrmustard/utils/typing.py @@ -95,5 +95,4 @@ class Batch(Protocol[T_co]): r"""Anything that can iterate over objects of type T_co.""" - def __iter__(self) -> Iterator[T_co]: - ... + def __iter__(self) -> Iterator[T_co]: ... diff --git a/tests/test_lab_dev/test_wires.py b/tests/test_lab_dev/test_wires.py new file mode 100644 index 000000000..cc4921fc9 --- /dev/null +++ b/tests/test_lab_dev/test_wires.py @@ -0,0 +1,96 @@ +# Copyright 2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Wires class.""" + +# pylint: disable=missing-function-docstring +# pylint: disable=protected-access + +import pytest +from mrmustard.lab_dev.wires import Wires + + +def test_wires_view_has_same_ids(): + w = Wires([0], [0], [0], [0]) + assert set(w.ids) == set(w._view().ids) + + +def test_view_can_edit_original(): + w = Wires([0], [0], [0], [0]) + w._view().ids = [9, 99, 999, 9999] + assert w.ids == [9, 99, 999, 9999] + + +def test_wire_subsets(): + w = Wires([0], [1], [2], [3]) + assert w.output.bra.modes == [0] + assert w.input.bra.modes == [1] + assert w.output.ket.modes == [2] + assert w.input.ket.modes == [3] + + +def test_wire_mode_subsets(): + w = Wires([10], [11], [12], [13]) + assert w[10].ids == w.output.bra.ids + assert w[11].ids == w.input.bra.ids + assert w[12].ids == w.output.ket.ids + assert w[13].ids == w.input.ket.ids + + +def test_indices(): + w = Wires([0, 1, 2], [3, 4, 5], [6, 7], [8]) + assert w.output.indices == [0, 1, 2, 6, 7] + assert w.bra.indices == [0, 1, 2, 3, 4, 5] + assert w.input.indices == [3, 4, 5, 8] + assert w.ket.indices == [6, 7, 8] + + +def test_setting_ids(): + w = Wires([0], [0], [0], [0]) + w.ids = [9, 99, 999, 9999] + assert w.ids == [9, 99, 999, 9999] + + +def test_non_overlapping_wires(): + w1 = Wires([0], [1], [2], [3]) + w2 = Wires([1], [2], [3], [4]) + w12 = Wires([0, 1], [1, 2], [2, 3], [3, 4]) + assert (w1 >> w2).modes == w12.modes + + +def test_cant_add_overlapping_wires(): + w1 = Wires([0], [1], [2], [3]) + w2 = Wires([0], [2], [3], [4]) + with pytest.raises(ValueError): + w1 >> w2 # pylint: disable=pointless-statement # type: ignore + + +def test_args(): + w = Wires([0], [1], [2], [3]) + assert w._args() == ((0,), (1,), (2,), (3,)) + + +def test_right_shift_general_contraction(): + # contracts 1,1 on bra side + # contracts 3,3 and 13,13 on ket side (note order doesn't matter) + u = Wires([1, 5], [2, 6, 15], [3, 7, 13], [4, 8]) + v = Wires([0, 9, 14], [1, 10], [2, 11], [13, 3, 12]) + assert (u >> v)._args() == ((0, 5, 9, 14), (2, 6, 10, 15), (2, 7, 11), (4, 8, 12)) + + +def test_error_if_cant_contract(): + u = Wires([], [], [0], []) # only output wire + v = Wires([], [], [0], []) # only output wire + with pytest.raises(ValueError): + u >> v # pylint: disable=pointless-statement # type: ignore diff --git a/tests/test_math/test_backend_manager.py b/tests/test_math/test_backend_manager.py index 963ca5cc9..64935f3c0 100644 --- a/tests/test_math/test_backend_manager.py +++ b/tests/test_math/test_backend_manager.py @@ -29,6 +29,7 @@ class TestBackendManager: r""" Tests the BackendManager. """ + l1 = [1.0] l2 = [1.0 + 0.0j, -2.0 + 2.0j] l3 = [[1.0, 2.0], [-3.0, 4.0]] diff --git a/tests/test_math/test_compactFock.py b/tests/test_math/test_compactFock.py index dd51e51fe..48761e5ab 100644 --- a/tests/test_math/test_compactFock.py +++ b/tests/test_math/test_compactFock.py @@ -1,6 +1,7 @@ """ Unit tests for mrmustard.math.compactFock.compactFock~ """ + import importlib import numpy as np