Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi representation #488

Closed
wants to merge 13 commits into from
148 changes: 104 additions & 44 deletions mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from mrmustard.math.parameters import Constant, Variable
from mrmustard.lab_dev.wires import Wires
from mrmustard.physics.triples import identity_Abc
from mrmustard.lab_dev.utils import BtoQ_mult_table

__all__ = ["CircuitComponent"]

Expand All @@ -60,6 +61,7 @@ class CircuitComponent:
a ``(modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket)``
where if any of the modes are out of order the representation
will be reordered.
multi_rep: A dictionary indicating what
name: The name of this component.
"""

Expand All @@ -69,6 +71,7 @@ def __init__(
self,
representation: Bargmann | Fock | None = None,
wires: Wires | Sequence[tuple[int]] | None = None,
multi_rep: dict | None = None,
name: str | None = None,
) -> None:
self._name = name
Expand Down Expand Up @@ -108,6 +111,11 @@ def __init__(
if self._representation:
self._representation = self._representation.reorder(tuple(perm))

if multi_rep:
self._multi_rep = multi_rep
else:
self._multi_rep = {key: None for key in self.wires.modes}

def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]:
"""
Inner serialization to be used by Circuit.serialize().
Expand Down Expand Up @@ -152,6 +160,15 @@ def _deserialize(cls, data: dict) -> CircuitComponent:

return cls(**data)

@property
def multi_rep(self) -> dict:
r"""
The multirepresentation of the object, is a dictionary from modes to either None,
'Q', or 'PS'.
None = Bargman, Q = Quandrature, PS = Phase Space
"""
return self._multi_rep

@property
def adjoint(self) -> CircuitComponent:
r"""
Expand All @@ -163,7 +180,9 @@ def adjoint(self) -> CircuitComponent:
kets = self.wires.ket.indices
rep = self.representation.reorder(kets + bras).conj() if self.representation else None

ret = CircuitComponent(rep, self.wires.adjoint, self.name)
ret = CircuitComponent(
rep, wires=self.wires.adjoint, multi_rep=self.multi_rep, name=self.name
)
ret.short_name = self.short_name
return ret

Expand All @@ -180,7 +199,7 @@ def dual(self) -> CircuitComponent:
ob = self.wires.bra.output.indices
rep = self.representation.reorder(ib + ob + ik + ok).conj() if self.representation else None

ret = CircuitComponent(rep, self.wires.dual, self.name)
ret = CircuitComponent(rep, wires=self.wires.dual, multi_rep=self.multi_rep, name=self.name)
ret.short_name = self.short_name

return ret
Expand Down Expand Up @@ -241,7 +260,13 @@ def representation(self) -> Representation | None:
r"""
A representation of this circuit component.
"""
return self._representation
from .circuit_components_utils import BtoQ

copy_of_self = self
for mode in self.modes:
if self.multi_rep[mode] == "Q":
copy_of_self = copy_of_self @ BtoQ([mode]).inverse()
return copy_of_self._representation

@property
def wires(self) -> Wires:
Expand All @@ -258,6 +283,7 @@ def from_bargmann(
modes_in_bra: Sequence[int] = (),
modes_out_ket: Sequence[int] = (),
modes_in_ket: Sequence[int] = (),
multi_rep: dict | None = None,
name: str | None = None,
) -> CircuitComponent:
r"""
Expand All @@ -276,7 +302,7 @@ def from_bargmann(
"""
repr = Bargmann(*triple)
wires = Wires(set(modes_out_bra), set(modes_in_bra), set(modes_out_ket), set(modes_in_ket))
return cls._from_attributes(repr, wires, name)
return cls._from_attributes(repr, wires=wires, multi_rep=multi_rep, name=name)

@classmethod
def from_quadrature(
Expand Down Expand Up @@ -315,13 +341,14 @@ def from_quadrature(
# NOTE: the representation is Bargmann here because we use the inverse of BtoQ on the B side
QQQQ = CircuitComponent._from_attributes(Bargmann(*triple), wires)
BBBB = QtoB_ib @ (QtoB_ik @ QQQQ @ QtoB_ok) @ QtoB_ob
return cls._from_attributes(BBBB.representation, wires, name)
return cls._from_attributes(BBBB.representation, wires=wires, name=name)

@classmethod
def _from_attributes(
cls,
representation: Representation,
wires: Wires,
multi_rep: dict | None = None,
name: str | None = None,
) -> CircuitComponent:
r"""
Expand Down Expand Up @@ -355,9 +382,10 @@ def _from_attributes(
ret = tp()
ret._name = name
ret._representation = representation
ret._multi_rep = multi_rep
ret._wires = wires
return ret
return CircuitComponent(representation, wires, name)
return CircuitComponent(representation, wires=wires, multi_rep=multi_rep, name=name)

def auto_shape(self, **_) -> tuple[int, ...]:
r"""
Expand Down Expand Up @@ -530,7 +558,7 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent:
ret = self._getitem_builtin(self.modes)
ret._representation = fock
except TypeError:
ret = self._from_attributes(fock, self.wires, self.name)
ret = self._from_attributes(fock, wires = self.wires, multi_rep = self.multi_rep, name = self.name)
if "manual_shape" in ret.__dict__:
del ret.manual_shape
return ret
Expand Down Expand Up @@ -641,15 +669,19 @@ def __add__(self, other: CircuitComponent) -> CircuitComponent:
raise ValueError("Cannot add components with different wires.")
rep = self.representation + other.representation
name = self.name if self.name == other.name else ""
return self._from_attributes(rep, self.wires, name)
return self._from_attributes(rep, wires=self.wires, name=name)

def __eq__(self, other) -> bool:
r"""
Whether this component is equal to another component.

Compares representations and wires, but not the other attributes (e.g. name and parameter set).
"""
return self.representation == other.representation and self.wires == other.wires
return (
self.representation == other.representation
and self.wires == other.wires
and self.multi_rep == other.multi_rep
)

def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent:
r"""
Expand All @@ -673,22 +705,25 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent:

wires_result, perm = self.wires @ other.wires
idx_z, idx_zconj = self._matmul_indices(other)
if type(self.representation) == type(other.representation):
self_rep = self.representation
other_rep = other.representation
else:
self_rep = self.to_bargmann().representation
other_rep = other.to_bargmann().representation
if type(self._representation) == type(other._representation):
self_rep = self._representation
other_rep = other._representation
else: # maybe we want to change back to .rep here?
self_rep = self.to_bargmann()._representation
other_rep = other.to_bargmann()._representation

rep = self_rep[idx_z] @ other_rep[idx_zconj]
rep = rep.reorder(perm) if perm else rep
return CircuitComponent._from_attributes(rep, wires_result, None)

return CircuitComponent._from_attributes(rep, wires=wires_result, name=None)

def __mul__(self, other: Scalar) -> CircuitComponent:
r"""
Implements the multiplication by a scalar from the right.
"""
return self._from_attributes(self.representation * other, self.wires, self.name)
return self._from_attributes(
self.representation * other, wires=self.wires, multi_rep=self.multi_rep, name=self.name
)

def __repr__(self) -> str:
repr = self.representation
Expand Down Expand Up @@ -746,38 +781,61 @@ def __rshift__(self, other: CircuitComponent | numbers.Number) -> CircuitCompone
>>> assert isinstance(Coherent([0], 1.0) >> Attenuator([0], 0.5), DM)
>>> assert isinstance(Coherent([0], 1.0) >> Coherent([0], 1.0).dual, complex)
"""
if hasattr(other, "__custom_rrshift__"):
return other.__custom_rrshift__(self)
from .circuit_components_utils import BtoQ

if isinstance(other, (numbers.Number, np.ndarray)):
return self * other
if not isinstance(other, BtoQ):
if hasattr(other, "__custom_rrshift__"):
return other.__custom_rrshift__(self)

s_k = self.wires.ket
s_b = self.wires.bra
o_k = other.wires.ket
o_b = other.wires.bra
if isinstance(other, (numbers.Number, np.ndarray)):
return self * other

only_ket = (not s_b and s_k) and (not o_b and o_k)
only_bra = (not s_k and s_b) and (not o_k and o_b)
both_sides = s_b and s_k and o_b and o_k
s_k = self.wires.ket
s_b = self.wires.bra
o_k = other.wires.ket
o_b = other.wires.bra

self_needs_bra = (not s_b and s_k) and (o_b and o_k)
self_needs_ket = (not s_k and s_b) and (o_b and o_k)
only_ket = (not s_b and s_k) and (not o_b and o_k)
only_bra = (not s_k and s_b) and (not o_k and o_b)
both_sides = s_b and s_k and o_b and o_k

other_needs_bra = (s_b and s_k) and (not o_b and o_k)
other_needs_ket = (s_b and s_k) and (not o_k and o_b)
self_needs_bra = (not s_b and s_k) and (o_b and o_k)
self_needs_ket = (not s_k and s_b) and (o_b and o_k)

if only_ket or only_bra or both_sides:
ret = self @ other
elif self_needs_bra or self_needs_ket:
ret = self.adjoint @ (self @ other)
elif other_needs_bra or other_needs_ket:
ret = (self @ other) @ other.adjoint
other_needs_bra = (s_b and s_k) and (not o_b and o_k)
other_needs_ket = (s_b and s_k) and (not o_k and o_b)

if only_ket or only_bra or both_sides:
ret = self @ other
elif self_needs_bra or self_needs_ket:
ret = self.adjoint @ (self @ other)
elif other_needs_bra or other_needs_ket:
ret = (self @ other) @ other.adjoint
else:
msg = f"``>>`` not supported between {self} and {other} because it's not clear "
msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components."
raise ValueError(msg)

ret._multi_rep = {mode: None for mode in ret.modes}
for mode in ret.modes:
if mode in list(set(self.modes) & set(other.modes)):
if not self.multi_rep[mode]:
ret._multi_rep[mode] = self.multi_rep[mode]
elif not other.multi_rep[mode]:
ret._multi_rep[mode] = self.multi_rep[mode]
elif mode in self.modes:
ret._multi_rep[mode] = self.multi_rep[mode]
else:
ret._multi_rep[mode] = other.multi_rep[mode]

return self._rshift_return(ret)
else:
msg = f"``>>`` not supported between {self} and {other} because it's not clear "
msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components."
raise ValueError(msg)
return self._rshift_return(ret)
for mode in other.modes:
if self.multi_rep[mode] == 'Q':
self._multi_rep[mode] = None
else:
self._multi_rep[mode] = 'Q'
return self

def __sub__(self, other: CircuitComponent) -> CircuitComponent:
r"""
Expand All @@ -787,13 +845,15 @@ def __sub__(self, other: CircuitComponent) -> CircuitComponent:
raise ValueError("Cannot subtract components with different wires.")
rep = self.representation - other.representation
name = self.name if self.name == other.name else ""
return self._from_attributes(rep, self.wires, name)
return self._from_attributes(rep, wires=self.wires, name=name)

def __truediv__(self, other: Scalar) -> CircuitComponent:
r"""
Implements the division by a scalar for circuit components.
"""
return self._from_attributes(self.representation / other, self.wires, self.name)
return self._from_attributes(
self.representation / other, wires=self.wires, multi_rep=self.multi_rep, name=self.name
)

def _ipython_display_(self):
# both reps might return None
Expand Down
1 change: 1 addition & 0 deletions mrmustard/lab_dev/circuit_components_utils/b_to_ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
representation=Bargmann.from_function(
fn=triples.displacement_map_s_parametrized_Abc, s=s, n_modes=len(modes)
),
multi_rep={mode: "BtoPS" for mode in modes},
name="BtoPS",
)
self._add_parameter(Constant(s, "s"))
Expand Down
1 change: 1 addition & 0 deletions mrmustard/lab_dev/circuit_components_utils/b_to_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
modes_out=modes,
modes_in=modes,
representation=repr,
multi_rep={mode: "BtoQ" for mode in modes},
name="BtoQ",
)
self._add_parameter(Constant(phi, "phi"))
Loading
Loading