Skip to content

Commit

Permalink
Updating qml.equal
Browse files Browse the repository at this point in the history
  • Loading branch information
mudit2812 committed Jul 27, 2023
1 parent aee9c1d commit e9571b8
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 24 deletions.
21 changes: 2 additions & 19 deletions pennylane/circuit_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,15 +298,7 @@ def ancestors(self, ops):
Returns:
list[Operator]: ancestors of the given operators
"""
# anc = set(
# self._graph.get_node_data(n)
# for n in set().union(
# # rx.ancestors() returns node indexes instead of node-values
# *(rx.ancestors(self._graph, self._indices[id(o)]) for o in ops)
# )
# )
# return anc - set(ops)
# rx.ancestors() returns node indexes instead of node-values
# rx.ancestors() returns node indices instead of node-values
all_indices = set().union(*(rx.ancestors(self._graph, self._indices[id(o)]) for o in ops))
double_op_indices = set(self._indices[id(o)] for o in ops)
ancestor_indices = all_indices - double_op_indices
Expand All @@ -322,15 +314,7 @@ def descendants(self, ops):
Returns:
list[Operator]: descendants of the given operators
"""
# des = set(
# self._graph.get_node_data(n)
# for n in set().union(
# # rx.descendants() returns node indexes instead of node-values
# *(rx.descendants(self._graph, self._indices[id(o)]) for o in ops)
# )
# )
# return des - set(ops)
# rx.descendants() returns node indexes instead of node-values
# rx.descendants() returns node indices instead of node-values
all_indices = set().union(*(rx.descendants(self._graph, self._indices[id(o)]) for o in ops))
double_op_indices = set(self._indices[id(o)] for o in ops)
ancestor_indices = all_indices - double_op_indices
Expand Down Expand Up @@ -418,7 +402,6 @@ def parametrized_layers(self):

# check if any of the dependents are in the
# currently assembled layer
# if set(current.ops) & sub:
if any(o1 is o2 for o1 in current.ops for o2 in sub):
# operator depends on current layer, start a new layer
current = Layer([], [])
Expand Down
10 changes: 9 additions & 1 deletion pennylane/ops/functions/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
This module contains the qml.equal function.
"""
# pylint: disable=too-many-arguments,too-many-return-statements
from collections.abc import Iterable
from functools import singledispatch
from typing import Union
import pennylane as qml
Expand Down Expand Up @@ -378,7 +379,14 @@ def _equal_shadow_measurements(op1: ShadowExpvalMP, op2: ShadowExpvalMP, **kwarg
"""Determine whether two ShadowExpvalMP objects are equal"""

wires_match = op1.wires == op2.wires
H_match = op1.H == op2.H

if isinstance(op1.H, Operator) and isinstance(op2.H, Operator):
H_match = qml.equal(op1.H, op2.H)
elif isinstance(op1.H, Iterable) and isinstance(op2.H, Iterable):
H_match = all(qml.equal(o1, o2) for o1, o2 in zip(op1.H, op2.H))
else:
return False

k_match = op1.k == op2.k

return wires_match and H_match and k_match
1 change: 0 additions & 1 deletion pennylane/transforms/metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,6 @@ def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire):
tapes = []
ids = []
# Exclude the backwards cone of layer_i from the backwards cone of layer_j
# ops_between_cgens = [op for op in layer_j.pre_ops if op not in layer_i.pre_ops]
ops_between_cgens = [
op1 for op1 in layer_j.pre_ops if not any(op1 is op2 for op2 in layer_i.pre_ops)
]
Expand Down
7 changes: 4 additions & 3 deletions tests/ops/functions/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@
qml.expval(qml.PauliX(1)),
qml.probs(wires=1),
qml.probs(wires=0),
qml.probs(qml.PauliZ(0)),
qml.probs(qml.PauliZ(1)),
qml.probs(op=qml.PauliZ(0)),
qml.probs(op=qml.PauliZ(1)),
qml.state(),
qml.vn_entropy(wires=0),
qml.vn_entropy(wires=0, log_base=np.e),
Expand All @@ -151,7 +151,8 @@
qml.shadow_expval(
H=qml.Hamiltonian(
[1.0, 1.0], [qml.PauliX(0) @ qml.PauliX(1), qml.PauliZ(0) @ qml.PauliZ(1)]
)
),
k=3
),
ExpectationMP(eigvals=[1, -1]),
ExpectationMP(eigvals=[1, 2]),
Expand Down

0 comments on commit e9571b8

Please sign in to comment.