Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wires final #333

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
* Added an Ansatz abstract class and PolyExpAnsatz concrete implementation. This is used in the Bargmann representation.
[(#295)](https://github.com/XanaduAI/MrMustard/pull/295)

* Added `complex_gaussian_integral` and `real_gaussian_integral` methods.
* Added `complex_gaussian_integral` method.
[(#295)](https://github.com/XanaduAI/MrMustard/pull/295)

* Added `Bargmann` representation (parametrized by Abc). Supports all algebraic operations and CV (exact) inner product.
[(#296)](https://github.com/XanaduAI/MrMustard/pull/296)

* Added a new class `Wires` in `mrmustard.lab` to handle the connectivity of objects in a circuit.
[(#330)](https://github.com/XanaduAI/MrMustard/pull/330)

### Breaking changes
* Removed circular dependencies by:
* Removing `graphics.py`--moved `ProgressBar` to `training` and `mikkel_plot` to `lab`.
Expand Down
6 changes: 2 additions & 4 deletions mrmustard/lab/abstract/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,10 @@ def outcome(self):
...

@abstractmethod
def _measure_fock(self, other: State) -> Union[State, float]:
...
def _measure_fock(self, other: State) -> Union[State, float]: ...

@abstractmethod
def _measure_gaussian(self, other: State) -> Union[State, float]:
...
def _measure_gaussian(self, other: State) -> Union[State, float]: ...

def primal(self, other: State) -> Union[State, float]:
"""performs the measurement procedure according to the representation of the incoming state"""
Expand Down
1 change: 0 additions & 1 deletion mrmustard/lab/abstract/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,6 @@ def get_modes(self, item) -> State:
fock_partitioned = fock.trace(self.dm(self.cutoffs), keep=item_idx)
return State(dm=fock_partitioned, modes=item)

# TODO: refactor
def __eq__(self, other) -> bool: # pylint: disable=too-many-return-statements
r"""Returns whether the states are equal."""
if self.num_modes != other.num_modes:
Expand Down
4 changes: 1 addition & 3 deletions mrmustard/lab/gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,9 +1041,7 @@ def primal(self, state):

coeff = math.cast(
math.exp(
-0.5
* self.phase_stdev.value**2
* math.arange(-dm.shape[-2] + 1, dm.shape[-1]) ** 2
-0.5 * self.phase_stdev.value**2 * math.arange(-dm.shape[-2] + 1, dm.shape[-1]) ** 2
),
dm.dtype,
)
Expand Down
327 changes: 327 additions & 0 deletions mrmustard/lab_dev/wires.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,327 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""`Wires` class for handling the connectivity of an object in a circuit."""

from __future__ import annotations

from typing import Iterable, Optional

import numpy as np

from mrmustard import settings

# pylint: disable=protected-access
# pylint: disable=import-outside-toplevel


class Wires:
r"""`Wires` class for handling the connectivity of an object in a circuit.
In MrMustard, ``CircuitComponent``s have a ``Wires`` object as attribute
to handle the wires of the component and to connect components together.

Wires are arranged into four groups, and each of the groups can
span multiple modes:
_______________
input bra modes --->| circuit |---> output bra modes
input ket modes --->| component |---> output ket modes
---------------
Each of the four groups can be empty. In particular, the wires of a state have
no inputs and the wires of a measurement have no outputs. Similarly,
a standard unitary transformation has no bra wires.

We refer to these four groups in a "standard order":
0. output_bra for all modes
1. input_bra for all modes
2. output_ket for all modes
3. input_ket for all modes

A ``Wires`` object can return subsets (views) of itself. Available subsets are:
- input/output (wires on input/output side)
- bra/ket (wires on bra/ket side)
- modes (wires on the given modes)

For example, ``wires.input`` returns a ``Wires`` object with only the input wires
(on bra and ket sides and on all the modes).
Note that these can be combined together: ``wires.input.bra[(1,2)] returns a
``Wires`` object with only the input bra wires on modes 1 and 2.
Note these are views of the original ``Wires`` object, i.e. we can set the ``ids`` on the
views and it will be set on the original, e.g. this is a valid way to connect two sets
of wires by setting their ids to be equal: ``wires1.output.ids = wires2.input.ids``.

A very useful feature of the ``Wires`` class is the support for the right shift
operator ``>>``. This allows us to connect two ``Wires`` objects together as
``wires1 >> wires2``. This will return the ``Wires`` object of the two
wires objects connected together as if they were in a circuit:
____________ ____________
in bra --->| wires1 |---> out bra ---> in bra --->| wires2 |---> out bra
in ket --->| |---> out ket ---> in ket --->| |---> out ket
------------ ------------
The returned ``Wires`` object will contain the surviving wires of the two
``Wires`` objects and it will raise an error if there are overlaps between
the surviving wires. This is especially useful for handling the ordering of the
wires when connecting components together: we are always guaranteed that a
``Wires`` object will provide the wires in the standard order.

Args:
modes_out_bra (Iterable[int]): The output modes on the bra side.
modes_in_bra (Iterable[int]): The input modes on the bra side.
modes_out_ket (Iterable[int]): The output modes on the ket side.
modes_in_ket (Iterable[int]): The input modes on the ket side.

Note that the order of the modes passed to initialize the object doesn't matter,
as they get sorted at init time.
"""

def __init__(
self,
modes_out_bra: Optional[Iterable[int]] = None,
modes_in_bra: Optional[Iterable[int]] = None,
modes_out_ket: Optional[Iterable[int]] = None,
modes_in_ket: Optional[Iterable[int]] = None,
) -> None:
modes_out_bra = modes_out_bra or []
modes_in_bra = modes_in_bra or []
modes_out_ket = modes_out_ket or []
modes_in_ket = modes_in_ket or []

self._modes = sorted(
set(modes_out_bra) | set(modes_in_bra) | set(modes_out_ket) | set(modes_in_ket)
)
randint = settings.rng.integers # MM random number generator
outbra = {m: randint(1, 2**62) if m in modes_out_bra else 0 for m in self._modes}
inbra = {m: randint(1, 2**62) if m in modes_in_bra else 0 for m in self._modes}
outket = {m: randint(1, 2**62) if m in modes_out_ket else 0 for m in self._modes}
inket = {m: randint(1, 2**62) if m in modes_in_ket else 0 for m in self._modes}
self._id_array = np.array([[outbra[m], inbra[m], outket[m], inket[m]] for m in self._modes])
self._mask = np.ones_like(self._id_array) # multiplicative mask

def _args(self):
r"""Returns the same args one needs to initialize this object."""
ob_modes = np.array(self._modes)[self._id_array[:, 0] > 0].tolist()
ib_modes = np.array(self._modes)[self._id_array[:, 1] > 0].tolist()
ok_modes = np.array(self._modes)[self._id_array[:, 2] > 0].tolist()
ik_modes = np.array(self._modes)[self._id_array[:, 3] > 0].tolist()
return tuple(ob_modes), tuple(ib_modes), tuple(ok_modes), tuple(ik_modes)

@classmethod
def _from_data(cls, id_array: np.ndarray, modes: list[int], mask=None):
r"""Private class method to initialize Wires object from the given data."""
w = cls()
w._id_array = id_array
w._modes = modes
w._mask = mask if mask is not None else np.ones_like(w._id_array)
return w

def _view(self, masked_rows: tuple[int, ...] = (), masked_cols: tuple[int, ...] = ()) -> Wires:
r"""A masked view of this Wires object."""
w = self._from_data(self._id_array, self._modes, self._mask.copy())
w._mask[masked_rows, :] = -1
w._mask[:, masked_cols] = -1
return w

def _mode(self, mode: int) -> np.ndarray:
"A slice of the id_array matrix at the given mode."
return np.maximum(0, self.id_array[[self._modes.index(mode)]])[0]

@property
def id_array(self) -> np.ndarray:
"The id_array of the available wires in the standard order (bra/ket x out/in x mode)."
return self._id_array * self._mask

@property
def ids(self) -> list[int]:
"The list of ids of the available wires in the standard order."
flat = self.id_array.T.ravel()
return flat[flat > 0].tolist()

@ids.setter
def ids(self, ids: list[int]):
"Sets the ids of the available wires."
if len(ids) != len(self.ids):
raise ValueError(f"wrong number of ids (expected {len(self.ids)}, got {len(ids)})")
self._id_array.flat[self.id_array.flatten() > 0] = ids

@property
def modes(self) -> list[int]:
"The set of modes spanned by the populated wires."
return [m for m in self._modes if any(self.id_array[self._modes.index(m)] > 0)]

@property
def indices(self) -> list[int]:
r"""Returns the array of indices of this subset in the standard order.
(bra/ket x out/in x mode). Use this to get the indices for bargmann contractions.
"""
flat = self.id_array.T.ravel()
flat = flat[flat != 0]
return np.where(flat > 0)[0].tolist()

@property
def input(self) -> Wires:
"A view of self without output wires"
return self._view(masked_cols=(0, 2))

@property
def output(self) -> Wires:
"A view of self without input wires"
return self._view(masked_cols=(1, 3))

@property
def ket(self) -> Wires:
"A view of self without bra wires"
return self._view(masked_cols=(0, 1))

@property
def bra(self) -> Wires:
"A view of self without ket wires"
return self._view(masked_cols=(2, 3))

@property
def adjoint(self) -> Wires:
r"""
The adjoint (ket <-> bra) of this wires object.
"""
return self._from_data(self._id_array[:, [2, 3, 0, 1]], self._modes, self._mask)

@property
def dual(self) -> Wires:
r"""
The dual (in <-> out) of this wires object.
"""
return self._from_data(self._id_array[:, [1, 0, 3, 2]], self._modes, self._mask)

def copy(self) -> Wires:
r"""A copy of this Wires object with new ids."""
w = Wires(*self._args())
w._mask = self._mask.copy()
return w

def __bool__(self) -> bool:
return len(self.ids) > 0

def __getitem__(self, modes: Iterable[int] | int) -> Wires:
"A view of this Wires object with wires only on the given modes."
modes = [modes] if isinstance(modes, int) else modes
idxs = tuple(list(self._modes).index(m) for m in set(self._modes).difference(modes))
return self._view(masked_rows=idxs)

def __lshift__(self, other: Wires) -> Wires:
return (other.dual >> self.dual).dual # how cool is this

@staticmethod
def _outin(si, so, oi, oo):
r"""Returns the output and input wires of the composite object made by connecting
two single-mode (ket or bra) objects like --|self|-- and --|other|--
At this stage we are guaranteed that the configurations `|self|-- |other|--` and
`--|self| --|other|` (which would be invalid) have already been excluded.
"""
if bool(so) == bool(oi): # if the inner wires are either both there or both not there
return np.array([oo, si], dtype=np.int64)
elif not si and not so: # no wires on self
return np.array([oo, oi], dtype=np.int64)
else: # no wires on other
return np.array([so, si], dtype=np.int64)

def __rshift__(self, other: Wires) -> Wires:
all_modes = sorted(set(self.modes) | set(other.modes))
new_id_array = np.zeros((len(all_modes), 4), dtype=np.int64)

for m in set(self.modes) & set(other.modes):
sob, sib, sok, sik = self._mode(m) # m-th row of self
oob, oib, ook, oik = other._mode(m) # m-th row of other

out_bra_issue = sob and oob and not oib
out_ket_issue = sok and ook and not oik
if out_bra_issue or out_ket_issue:
raise ValueError(f"Output wire overlap at mode {m}")
in_bra_issue = oib and sib and not sob
in_ket_issue = oik and sik and not sok
if in_bra_issue or in_ket_issue:
raise ValueError(f"Input wire overlap at mode {m}")

new_id_array[all_modes.index(m)] = np.hstack(
[self._outin(sib, sob, oib, oob), self._outin(sik, sok, oik, ook)]
)
for m in set(self.modes) - set(other.modes):
new_id_array[all_modes.index(m)] = self._mode(m)
for m in set(other.modes) - set(self.modes):
new_id_array[all_modes.index(m)] = other._mode(m)

return self._from_data(new_id_array, all_modes)

def __repr__(self) -> str:
ob_modes, ib_modes, ok_modes, ik_modes = self._args()
return f"Wires({ob_modes}, {ib_modes}, {ok_modes}, {ik_modes})"

def _repr_html_(self): # pragma: no cover
"A matrix plot of the id_array."
row_labels = map(str, self._modes)
col_labels = ["bra-out", "bra-in", "ket-out", "ket-in"]
array = np.abs(self.id_array) / (self.id_array + 1e-15)
idxs = (i for i in self.indices)
box_size = "60px" # Set the size of the squares
html = '<table style="border-collapse: collapse; border: 1px solid black;">'
# colors
active = "#5b9bd5"
inactive = "#d6e8f7"

# Add column headers
html += "<tr>"
for label in [""] + col_labels: # Add an empty string for the top-left cell
html += f'<th style="border: 1px solid black; padding: 5px;">{label}</th>'
html += "</tr>"

# Initialize rows with row labels
rows_html = [
f'<tr><td style="border: 1px solid black; padding: 5px;">{label}</td>'
for label in row_labels
]

# Add table cells (column by column)
for label, col in zip(col_labels, array.T):
for row_idx, value in enumerate(col):
color = (
"white"
if np.isclose(value, 0)
else (active if np.isclose(value, 1) else inactive)
)
cell_html = f'<td style="border: 1px solid black; padding: 5px;\
width: {box_size}px; height: {box_size}px; background-color:\
{color}; box-sizing: border-box;'
if color == active:
cell_html += (
f' text-align: center; vertical-align: middle;">{str(next(idxs))}</td>'
)
else:
cell_html += '"></td>'
rows_html[row_idx] += cell_html

# Close the rows and add them to the HTML table
for row_html in rows_html:
row_html += "</tr>"
html += row_html

html += "</table>"

try:
from IPython.core.display import (
HTML,
display,
)

display(HTML(html))
except ImportError as e:
raise ImportError(
"To display the wires in a jupyter notebook you need to `pip install IPython` first."
) from e
6 changes: 3 additions & 3 deletions mrmustard/training/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,9 @@ def call(
orig_cost = np.array(optimizer.callback_history["orig_cost"][-1]).item()
obj_scalars[f"{obj_tag}/orig_cost"] = orig_cost
if self.cost_converter is not None:
obj_scalars[
f"{obj_tag}/{self.cost_converter.__name__}(orig_cost)"
] = self.cost_converter(orig_cost)
obj_scalars[f"{obj_tag}/{self.cost_converter.__name__}(orig_cost)"] = (
self.cost_converter(orig_cost)
)

for k, v in obj_scalars.items():
tf.summary.scalar(k, data=v, step=self.optimizer_step)
Expand Down
3 changes: 1 addition & 2 deletions mrmustard/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,4 @@
class Batch(Protocol[T_co]):
r"""Anything that can iterate over objects of type T_co."""

def __iter__(self) -> Iterator[T_co]:
...
def __iter__(self) -> Iterator[T_co]: ...
Loading
Loading