Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
apchytr committed Sep 19, 2024
1 parent 1ddafdb commit 5e21bfd
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 8 deletions.
6 changes: 4 additions & 2 deletions mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,10 @@ def adjoint(self) -> CircuitComponent:
bras = self.wires.bra.indices
kets = self.wires.ket.indices
rep = self.representation.reorder(kets + bras).conj() if self.representation else None

ret = CircuitComponent(rep, self.wires.adjoint, self.name)
ret.short_name = self.short_name
for param in self.parameter_set.all_parameters.values():
ret._add_parameter(param)
return ret

@property
Expand All @@ -184,7 +185,8 @@ def dual(self) -> CircuitComponent:

ret = CircuitComponent(rep, self.wires.dual, self.name)
ret.short_name = self.short_name

for param in self.parameter_set.all_parameters.values():
ret._add_parameter(param)
return ret

@cached_property
Expand Down
9 changes: 9 additions & 0 deletions mrmustard/math/parameter_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,12 @@ def __bool__(self) -> bool:
if self._constants or self._variables:
return True
return False

def __eq__(self, other: Any) -> bool:
if not isinstance(other, ParameterSet):
return False
return (
self._names == other._names
and self._constants == other._constants
and self._variables == other._variables
)
6 changes: 6 additions & 0 deletions tests/test_lab_dev/test_circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,16 @@ def test_adjoint(self):
assert isinstance(d1_adj, CircuitComponent)
assert d1_adj.name == d1.name
assert d1_adj.wires == d1.wires.adjoint
assert d1_adj.parameter_set == d1.parameter_set
assert (
d1_adj.representation == d1.representation.conj()
) # this holds for the Dgate but not in general

d1_adj_adj = d1_adj.adjoint
assert isinstance(d1_adj_adj, CircuitComponent)
assert d1_adj_adj.wires == d1.wires
assert d1_adj_adj.parameter_set == d1_adj.parameter_set
assert d1_adj_adj.parameter_set == d1.parameter_set
assert d1_adj_adj.representation == d1.representation

def test_dual(self):
Expand All @@ -138,11 +141,14 @@ def test_dual(self):
assert isinstance(d1_dual, CircuitComponent)
assert d1_dual.name == d1.name
assert d1_dual.wires == d1.wires.dual
assert d1_dual.parameter_set == d1.parameter_set
assert (vac >> d1 >> d1_dual).representation == vac.representation
assert (vac >> d1_dual >> d1).representation == vac.representation

d1_dual_dual = d1_dual.dual
assert isinstance(d1_dual_dual, CircuitComponent)
assert d1_dual_dual.parameter_set == d1_dual.parameter_set
assert d1_dual_dual.parameter_set == d1.parameter_set
assert d1_dual_dual.wires == d1.wires
assert d1_dual_dual.representation == d1.representation

Expand Down
12 changes: 6 additions & 6 deletions tests/test_lab_dev/test_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ def test_repr(self):

circ2 = Circuit([vac012, s01, bs01, bs12, cc, n12.dual])
r2 = ""
r2 += "\nmode 0: ◖Vac◗──S(0.0,2.0)──╭•──────────────────────────CC──|3)="
r2 += "\nmode 1: ◖Vac◗──S(1.0,3.0)──╰BS(0.0,0.0)──╭•────────────CC──|3)="
r2 += "\nmode 2: ◖Vac◗────────────────────────────╰BS(0.0,0.0)──────────"
r2 += "\nmode 0: ◖Vac◗──S(0.0,2.0)──╭•──────────────────────────CC──|3)=(3,3)"
r2 += "\nmode 1: ◖Vac◗──S(1.0,3.0)──╰BS(0.0,0.0)──╭•────────────CC──|3)=(3,3)"
r2 += "\nmode 2: ◖Vac◗────────────────────────────╰BS(0.0,0.0)───────────────"
assert repr(circ2) == r2 + "\n\n"

circ3 = Circuit([bs01, bs01, bs01, bs01, bs01, bs01, bs01, bs01, bs01, bs01, bs01])
Expand All @@ -213,9 +213,9 @@ def test_repr(self):

circ4 = Circuit([vac01, s01, vac2, bs01, bs12, n2.dual, cc, n12.dual])
r4 = ""
r4 += "\nmode 0: ◖Vac◗──S(0.0,2.0)──╭•──────────────────────────CC────|3)="
r4 += "\nmode 1: ◖Vac◗──S(1.0,3.0)──╰BS(0.0,0.0)──╭•────────────CC────|3)="
r4 += "\nmode 2: ◖Vac◗─────────────────────╰BS(0.0,0.0)──|3)= "
r4 += "\nmode 0: ◖Vac◗──S(0.0,2.0)──╭•──────────────────────────CC─────────|3)=(3,3)"
r4 += "\nmode 1: ◖Vac◗──S(1.0,3.0)──╰BS(0.0,0.0)──╭•────────────CC─────────|3)=(3,3)"
r4 += "\nmode 2: ◖Vac◗─────────────────────╰BS(0.0,0.0)──|3)=(3,3) "
assert repr(circ4) == r4 + "\n\n"

circ5 = Circuit() >> vac1 >> bs01 >> vac1.dual >> vac1 >> bs01 >> vac1.dual
Expand Down
28 changes: 28 additions & 0 deletions tests/test_math/test_parameter_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,34 @@ def test_to_string(self):
assert ps.to_string(3) == "1.234, 2.346, 3.457"
assert ps.to_string(10) == "1.2345, 2.3456, 3.4567"

def test_eq(self):
const1 = Constant(1, "c1")
const2 = Constant([2, 3, 4], "c2")
var1 = Variable(5, "v1")
var2 = Variable([6, 7, 8], "v2")

ps1 = ParameterSet()
ps1.add_parameter(const1)
ps1.add_parameter(const2)
ps1.add_parameter(var1)
ps1.add_parameter(var2)

assert ps1 != 1.0

ps2 = ParameterSet()
ps2.add_parameter(const1)
ps2.add_parameter(const2)
ps2.add_parameter(var1)
ps2.add_parameter(var2)

assert ps1 == ps2

ps3 = ParameterSet()
ps3.add_parameter(const1)
ps3.add_parameter(var1)

assert ps1 != ps3

def test_get_item(self):
const1 = Constant(1, "c1")
const2 = Constant([2, 3, 4], "c2")
Expand Down

0 comments on commit 5e21bfd

Please sign in to comment.