Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the derivative of MottonenStatePreparation where possible #5774

Merged
merged 11 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@

<h3>Bug fixes 🐛</h3>

* Fixes a bug where `MottonenStatePreparation` produces wrong derivatives at special parameter values.
[(#5774)](https://github.com/PennyLaneAI/pennylane/pull/5774)

* Fixes a bug where fractional powers and adjoints of operators were commuted, which is
not well-defined/correct in general. Adjoints of fractional powers can no longer be evaluated.
[(#5835)](https://github.com/PennyLaneAI/pennylane/pull/5835)
Expand Down
21 changes: 16 additions & 5 deletions pennylane/templates/state_preparations/mottonen.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def compute_theta(alpha):
(tensor_like): rotation angles theta
"""
ln = alpha.shape[-1]
k = np.log2(ln)

M_trans = np.zeros(shape=(ln, ln))
for i in range(len(M_trans)):
Expand All @@ -91,7 +90,7 @@ def compute_theta(alpha):

theta = qml.math.transpose(qml.math.dot(M_trans, qml.math.transpose(alpha)))

return theta / 2**k
return theta / ln


def _apply_uniform_rotation_dagger(gate, alpha, control_wires, target_wire):
Expand Down Expand Up @@ -124,7 +123,11 @@ def _apply_uniform_rotation_dagger(gate, alpha, control_wires, target_wire):
gray_code_rank = len(control_wires)

if gray_code_rank == 0:
if qml.math.is_abstract(theta) or qml.math.all(theta[..., 0] != 0.0):
if (
qml.math.is_abstract(theta)
or qml.math.requires_grad(theta)
or qml.math.all(theta[..., 0] != 0.0)
):
op_list.append(gate(theta[..., 0], wires=[target_wire]))
return op_list

Expand All @@ -137,7 +140,11 @@ def _apply_uniform_rotation_dagger(gate, alpha, control_wires, target_wire):
]

for i, control_index in enumerate(control_indices):
if qml.math.is_abstract(theta) or qml.math.all(theta[..., i] != 0.0):
if (
qml.math.is_abstract(theta)
or qml.math.requires_grad(theta)
or qml.math.all(theta[..., i] != 0.0)
):
op_list.append(gate(theta[..., i], wires=[target_wire]))
op_list.append(qml.CNOT(wires=[control_wires[control_index], target_wire]))
return op_list
Expand Down Expand Up @@ -366,7 +373,11 @@ def compute_decomposition(state_vector, wires): # pylint: disable=arguments-dif
op_list.extend(_apply_uniform_rotation_dagger(qml.RY, alpha_y_k, control, target))

# If necessary, apply inverse z rotation cascade to prepare correct phases of amplitudes
if qml.math.is_abstract(omega) or not qml.math.allclose(omega, 0):
if (
qml.math.is_abstract(omega)
or qml.math.requires_grad(omega)
or not qml.math.allclose(omega, 0)
):
for k in range(len(wires_reverse), 0, -1):
alpha_z_k = _get_alpha_z(omega, len(wires_reverse), k)
control = wires_reverse[k:]
Expand Down
245 changes: 98 additions & 147 deletions tests/templates/test_state_preparations/test_mottonen_state_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,74 +67,60 @@ def test_get_alpha_y(self, current_qubit, expected, tol):
assert np.allclose(res, expected, atol=tol)


# fmt: off
fixed_states = (
[
-0.17133152 - 0.18777771j, 0.00240643 - 0.40704011j, 0.18684538 - 0.36315606j,
-0.07096948 + 0.104501j, 0.30357755 - 0.23831927j, -0.38735106 + 0.36075556j,
0.12351096 - 0.0539908j, 0.27942828 - 0.24810483j,
],
[
-0.29972867 + 0.04964242j, -0.28309418 + 0.09873227j, 0.00785743 - 0.37560696j,
-0.3825148 + 0.00674343j, -0.03008048 + 0.31119167j, 0.03666351 - 0.15935903j,
-0.25358831 + 0.35461265j, -0.32198531 + 0.33479292j,
],
[
-0.39340123 + 0.05705932j, 0.1980509 - 0.24234781j, 0.27265585 - 0.0604432j,
-0.42641249 + 0.25767258j, 0.40386614 - 0.39925987j, 0.03924761 + 0.13193724j,
-0.06059103 - 0.01753834j, 0.21707136 - 0.15887973j,
],
[
-1.33865287e-01 + 0.09802308j, 1.25060033e-01 + 0.16087698j, -4.14678130e-01 - 0.00774832j,
1.10121136e-01 + 0.37805482j, -3.21284864e-01 + 0.21521063j, -2.23121454e-04 + 0.28417422j,
5.64131205e-02 + 0.38135286j, 2.32694503e-01 + 0.41331133j,
],
)
# fmt: on
decomposition_test_cases = [
([1, 0], 0, np.eye(8)[0]),
([1, 0], [0], np.eye(8)[0]),
([1, 0], [1], np.eye(8)[0]),
([1, 0], [2], np.eye(8)[0]),
([0, 1], [0], np.eye(8)[4]),
([0, 1], [1], np.eye(8)[2]),
([0, 1], [2], np.eye(8)[1]),
([0, 1, 0, 0], [0, 1], np.eye(8)[2]),
([0, 0, 0, 1], [0, 2], np.eye(8)[5]),
([0, 0, 0, 1], [1, 2], np.eye(8)[3]),
(np.eye(8)[0], [0, 1, 2], np.eye(8)[0]),
(1j * np.eye(8)[4], [0, 1, 2], 1j * np.eye(8)[4]),
(x := np.array([1, 0, 0, 0, 1, 1j, -1, 0]) / 2, [0, 1, 2], x),
(x := np.array([1, 0, 0, 0, 2j, 2j, 0, 0]) / 3, [0, 1, 2], x),
(x := np.array([2, 0, 0, 0, 1, 0, 0, 2]) / 3, [0, 1, 2], x),
(x := np.array([1, 1j, 1, -1j, 1, 1, 1, 1j]) / np.sqrt(8), [0, 1, 2], x),
(fixed_states[0], [0, 1, 2], fixed_states[0]),
(fixed_states[1], [0, 1, 2], fixed_states[1]),
(fixed_states[2], [0, 1, 2], fixed_states[2]),
(fixed_states[3], [0, 1, 2], fixed_states[3]),
(x := np.array([1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0]), [0, 1, 2], x),
(np.array([1 / 2, 0, 1j / 2, 1j / np.sqrt(2)]), [0, 1], x),
]

dwierichs marked this conversation as resolved.
Show resolved Hide resolved

class TestDecomposition:
"""Tests that the template defines the correct decomposition."""

# fmt: off
@pytest.mark.parametrize("state_vector,wires,target_state", [
([1, 0], 0, [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [0], [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [1], [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [2], [1, 0, 0, 0, 0, 0, 0, 0]),
([0, 1], [0], [0, 0, 0, 0, 1, 0, 0, 0]),
([0, 1], [1], [0, 0, 1, 0, 0, 0, 0, 0]),
([0, 1], [2], [0, 1, 0, 0, 0, 0, 0, 0]),
([0, 1, 0, 0], [0, 1], [0, 0, 1, 0, 0, 0, 0, 0]),
([0, 0, 0, 1], [0, 2], [0, 0, 0, 0, 0, 1, 0, 0]),
([0, 0, 0, 1], [1, 2], [0, 0, 0, 1, 0, 0, 0, 0]),
([1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 2], [1, 0, 0, 0, 0, 0, 0, 0]),
([0, 0, 0, 0, 1j, 0, 0, 0], [0, 1, 2], [0, 0, 0, 0, 1j, 0, 0, 0]),
([1 / 2, 0, 0, 0, 1 / 2, 1j / 2, -1 / 2, 0], [0, 1, 2], [1 / 2, 0, 0, 0, 1 / 2, 1j / 2, -1 / 2, 0]),
([1 / 3, 0, 0, 0, 2j / 3, 2j / 3, 0, 0], [0, 1, 2], [1 / 3, 0, 0, 0, 2j / 3, 2j / 3, 0, 0]),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], [0, 1, 2], [2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3]),
(
[1 / np.sqrt(8), 1j / np.sqrt(8), 1 / np.sqrt(8), -1j / np.sqrt(8), 1 / np.sqrt(8),
1 / np.sqrt(8), 1 / np.sqrt(8), 1j / np.sqrt(8)],
[0, 1, 2],
[1 / np.sqrt(8), 1j / np.sqrt(8), 1 / np.sqrt(8), -1j / np.sqrt(8), 1 / np.sqrt(8),
1 / np.sqrt(8), 1 / np.sqrt(8), 1j / np.sqrt(8)],
),
(
[-0.17133152 - 0.18777771j, 0.00240643 - 0.40704011j, 0.18684538 - 0.36315606j, -0.07096948 + 0.104501j,
0.30357755 - 0.23831927j, -0.38735106 + 0.36075556j, 0.12351096 - 0.0539908j,
0.27942828 - 0.24810483j],
[0, 1, 2],
[-0.17133152 - 0.18777771j, 0.00240643 - 0.40704011j, 0.18684538 - 0.36315606j, -0.07096948 + 0.104501j,
0.30357755 - 0.23831927j, -0.38735106 + 0.36075556j, 0.12351096 - 0.0539908j,
0.27942828 - 0.24810483j],
),
(
[-0.29972867 + 0.04964242j, -0.28309418 + 0.09873227j, 0.00785743 - 0.37560696j,
-0.3825148 + 0.00674343j, -0.03008048 + 0.31119167j, 0.03666351 - 0.15935903j,
-0.25358831 + 0.35461265j, -0.32198531 + 0.33479292j],
[0, 1, 2],
[-0.29972867 + 0.04964242j, -0.28309418 + 0.09873227j, 0.00785743 - 0.37560696j,
-0.3825148 + 0.00674343j, -0.03008048 + 0.31119167j, 0.03666351 - 0.15935903j,
-0.25358831 + 0.35461265j, -0.32198531 + 0.33479292j],
),
(
[-0.39340123 + 0.05705932j, 0.1980509 - 0.24234781j, 0.27265585 - 0.0604432j, -0.42641249 + 0.25767258j,
0.40386614 - 0.39925987j, 0.03924761 + 0.13193724j, -0.06059103 - 0.01753834j,
0.21707136 - 0.15887973j],
[0, 1, 2],
[-0.39340123 + 0.05705932j, 0.1980509 - 0.24234781j, 0.27265585 - 0.0604432j, -0.42641249 + 0.25767258j,
0.40386614 - 0.39925987j, 0.03924761 + 0.13193724j, -0.06059103 - 0.01753834j,
0.21707136 - 0.15887973j],
),
(
[-1.33865287e-01 + 0.09802308j, 1.25060033e-01 + 0.16087698j, -4.14678130e-01 - 0.00774832j,
1.10121136e-01 + 0.37805482j, -3.21284864e-01 + 0.21521063j, -2.23121454e-04 + 0.28417422j,
5.64131205e-02 + 0.38135286j, 2.32694503e-01 + 0.41331133j],
[0, 1, 2],
[-1.33865287e-01 + 0.09802308j, 1.25060033e-01 + 0.16087698j, -4.14678130e-01 - 0.00774832j,
1.10121136e-01 + 0.37805482j, -3.21284864e-01 + 0.21521063j, -2.23121454e-04 + 0.28417422j,
5.64131205e-02 + 0.38135286j, 2.32694503e-01 + 0.41331133j],
),
([1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0], [0, 1, 2],
[1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0]),
([1 / 2, 0, 1j / 2, 1j / np.sqrt(2)], [0, 1], [1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0]),
])
# fmt: on
@pytest.mark.parametrize("state_vector,wires,target_state", decomposition_test_cases)
def test_state_preparation(self, state_vector, wires, target_state):
"""Tests that the template produces correct states."""

Expand All @@ -147,71 +133,7 @@ def circuit():

assert np.allclose(state, target_state)

# fmt: off
@pytest.mark.parametrize("state_vector,wires,target_state", [
([1, 0], 0, [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [0], [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [1], [1, 0, 0, 0, 0, 0, 0, 0]),
([1, 0], [2], [1, 0, 0, 0, 0, 0, 0, 0]),
([0, 1], [0], [0, 0, 0, 0, 1, 0, 0, 0]),
([0, 1], [1], [0, 0, 1, 0, 0, 0, 0, 0]),
([0, 1], [2], [0, 1, 0, 0, 0, 0, 0, 0]),
([0, 1, 0, 0], [0, 1], [0, 0, 1, 0, 0, 0, 0, 0]),
([0, 0, 0, 1], [0, 2], [0, 0, 0, 0, 0, 1, 0, 0]),
([0, 0, 0, 1], [1, 2], [0, 0, 0, 1, 0, 0, 0, 0]),
([1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 2], [1, 0, 0, 0, 0, 0, 0, 0]),
([0, 0, 0, 0, 1j, 0, 0, 0], [0, 1, 2], [0, 0, 0, 0, 1j, 0, 0, 0]),
([1 / 2, 0, 0, 0, 1 / 2, 1j / 2, -1 / 2, 0], [0, 1, 2], [1 / 2, 0, 0, 0, 1 / 2, 1j / 2, -1 / 2, 0]),
([1 / 3, 0, 0, 0, 2j / 3, 2j / 3, 0, 0], [0, 1, 2], [1 / 3, 0, 0, 0, 2j / 3, 2j / 3, 0, 0]),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], [0, 1, 2], [2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3]),
(
[1 / np.sqrt(8), 1j / np.sqrt(8), 1 / np.sqrt(8), -1j / np.sqrt(8), 1 / np.sqrt(8),
1 / np.sqrt(8), 1 / np.sqrt(8), 1j / np.sqrt(8)],
[0, 1, 2],
[1 / np.sqrt(8), 1j / np.sqrt(8), 1 / np.sqrt(8), -1j / np.sqrt(8), 1 / np.sqrt(8),
1 / np.sqrt(8), 1 / np.sqrt(8), 1j / np.sqrt(8)],
),
(
[-0.17133152 - 0.18777771j, 0.00240643 - 0.40704011j, 0.18684538 - 0.36315606j, -0.07096948 + 0.104501j,
0.30357755 - 0.23831927j, -0.38735106 + 0.36075556j, 0.12351096 - 0.0539908j,
0.27942828 - 0.24810483j],
[0, 1, 2],
[-0.17133152 - 0.18777771j, 0.00240643 - 0.40704011j, 0.18684538 - 0.36315606j, -0.07096948 + 0.104501j,
0.30357755 - 0.23831927j, -0.38735106 + 0.36075556j, 0.12351096 - 0.0539908j,
0.27942828 - 0.24810483j],
),
(
[-0.29972867 + 0.04964242j, -0.28309418 + 0.09873227j, 0.00785743 - 0.37560696j,
-0.3825148 + 0.00674343j, -0.03008048 + 0.31119167j, 0.03666351 - 0.15935903j,
-0.25358831 + 0.35461265j, -0.32198531 + 0.33479292j],
[0, 1, 2],
[-0.29972867 + 0.04964242j, -0.28309418 + 0.09873227j, 0.00785743 - 0.37560696j,
-0.3825148 + 0.00674343j, -0.03008048 + 0.31119167j, 0.03666351 - 0.15935903j,
-0.25358831 + 0.35461265j, -0.32198531 + 0.33479292j],
),
(
[-0.39340123 + 0.05705932j, 0.1980509 - 0.24234781j, 0.27265585 - 0.0604432j, -0.42641249 + 0.25767258j,
0.40386614 - 0.39925987j, 0.03924761 + 0.13193724j, -0.06059103 - 0.01753834j,
0.21707136 - 0.15887973j],
[0, 1, 2],
[-0.39340123 + 0.05705932j, 0.1980509 - 0.24234781j, 0.27265585 - 0.0604432j, -0.42641249 + 0.25767258j,
0.40386614 - 0.39925987j, 0.03924761 + 0.13193724j, -0.06059103 - 0.01753834j,
0.21707136 - 0.15887973j],
),
(
[-1.33865287e-01 + 0.09802308j, 1.25060033e-01 + 0.16087698j, -4.14678130e-01 - 0.00774832j,
1.10121136e-01 + 0.37805482j, -3.21284864e-01 + 0.21521063j, -2.23121454e-04 + 0.28417422j,
5.64131205e-02 + 0.38135286j, 2.32694503e-01 + 0.41331133j],
[0, 1, 2],
[-1.33865287e-01 + 0.09802308j, 1.25060033e-01 + 0.16087698j, -4.14678130e-01 - 0.00774832j,
1.10121136e-01 + 0.37805482j, -3.21284864e-01 + 0.21521063j, -2.23121454e-04 + 0.28417422j,
5.64131205e-02 + 0.38135286j, 2.32694503e-01 + 0.41331133j],
),
([1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0], [0, 1, 2],
[1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0]),
([1 / 2, 0, 1j / 2, 1j / np.sqrt(2)], [0, 1], [1 / 2, 0, 0, 0, 1j / 2, 0, 1j / np.sqrt(2), 0]),
])
# fmt: on
@pytest.mark.parametrize("state_vector,wires,target_state", decomposition_test_cases)
def test_state_preparation_probability_distribution(
self, tol, state_vector, wires, target_state
):
Expand All @@ -224,32 +146,32 @@ def circuit():
qml.expval(qml.PauliZ(0)),
qml.expval(qml.PauliZ(1)),
qml.expval(qml.PauliZ(2)),
qml.state(),
qml.probs(),
)

results = circuit()

state = results[-1].ravel()
probabilities = results[-1].ravel()

probabilities = np.abs(state) ** 2
target_probabilities = np.abs(target_state) ** 2

assert np.allclose(probabilities, target_probabilities, atol=tol, rtol=0)

# fmt: off
@pytest.mark.parametrize("state_vector, n_wires", [
([1 / 2, 1 / 2, 1 / 2, 1 / 2], 2),
([1, 0, 0, 0], 2),
([0, 1, 0, 0], 2),
([0, 0, 0, 1], 2),
([0, 1, 0, 0, 0, 0, 0, 0], 3),
([0, 0, 0, 0, 1, 0, 0, 0], 3),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], 3),
([1 / 2, 0, 0, 0, 1 / 2, 1 / 2, 1 / 2, 0], 3),
([1 / 3, 0, 0, 0, 2 / 3, 2 / 3, 0, 0], 3),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], 3),
])
# fmt: on
@pytest.mark.parametrize(
"state_vector, n_wires",
[
([1 / 2, 1 / 2, 1 / 2, 1 / 2], 2),
([1, 0, 0, 0], 2),
([0, 1, 0, 0], 2),
([0, 0, 0, 1], 2),
([0, 1, 0, 0, 0, 0, 0, 0], 3),
([0, 0, 0, 0, 1, 0, 0, 0], 3),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], 3),
([1 / 2, 0, 0, 0, 1 / 2, 1 / 2, 1 / 2, 0], 3),
([1 / 3, 0, 0, 0, 2 / 3, 2 / 3, 0, 0], 3),
([2 / 3, 0, 0, 0, 1 / 3, 0, 0, 2 / 3], 3),
],
)
def test_RZ_skipped(self, mocker, state_vector, n_wires):
"""Tests that the cascade of RZ gates is skipped for real-valued states."""

Expand Down Expand Up @@ -492,3 +414,32 @@ def circuit(state):
expected = np.zeros(8)
expected[0] = 1.0
assert qml.math.allclose(actual, expected)


@pytest.mark.jax
@pytest.mark.parametrize("shots, atol", [(None, 0.005), (1000000, 0.05)])
dwierichs marked this conversation as resolved.
Show resolved Hide resolved
def test_jacobians_with_and_without_jit_match(shots, atol):
"""Test that the Jacobian of the circuit is the same with and without jit."""
import jax

dev = qml.device("default.qubit", shots=shots, seed=7890234)
dev_no_shots = qml.device("default.qubit", shots=None)

def circuit(coeffs):
qml.MottonenStatePreparation(coeffs, wires=[0, 1])
return qml.probs(wires=[0, 1])

circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", h=0.05)
circuit_exact = qml.QNode(circuit, dev_no_shots)

params = jax.numpy.array([0.5, 0.5, 0.5, 0.5])
jac_exact_fn = jax.jacobian(circuit_exact)
jac_fn = jax.jacobian(circuit_fd)
jac_jit_fn = jax.jit(jac_fn)

jac_exact = jac_exact_fn(params)
jac = jac_fn(params)
jac_jit = jac_jit_fn(params)

assert qml.math.allclose(jac_exact, jac_jit, atol=atol)
assert qml.math.allclose(jac, jac_jit, atol=atol)
22 changes: 15 additions & 7 deletions tests/templates/test_subroutines/test_qubitization.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,16 +223,22 @@ def test_qnode_autograd(self):
assert np.allclose(res, self.exp_grad, atol=1e-5)

@pytest.mark.jax
@pytest.mark.parametrize(
"use_jit , shots",
((False, None), (True, None), (False, 50000)),
) # TODO: (True, 50000) fails because jax.jit on jax.grad does not work with AmplitudeEmbedding
@pytest.mark.parametrize("use_jit", (False, True))
@pytest.mark.parametrize("shots", (None, 50000))
@pytest.mark.parametrize("device", ["default.qubit", "default.qubit.legacy"])
def test_qnode_jax(self, shots, use_jit, device):
""" "Test that the QNode executes and is differentiable with JAX. The shots
argument controls whether autodiff or parameter-shift gradients are used."""
import jax

# TODO: Allow the following cases once their underlying issues are fixed:
# (True, 50000): jax.jit on jax.grad does not work with AmplitudeEmbedding currently
# (False, 50000): Since #5774, the decomposition of AmplitudeEmbedding triggered by
# param-shift includes a GlobalPhase always. GlobalPhase will only be
# param-shift-compatible again once #5620 is merged in.
if shots is not None:
pytest.xfail()

jax.config.update("jax_enable_x64", True)

if device == "default.qubit":
Expand All @@ -256,14 +262,16 @@ def test_qnode_jax(self, shots, use_jit, device):
assert np.allclose(jac, self.exp_grad, atol=0.05)

@pytest.mark.torch
@pytest.mark.parametrize(
"shots", [None]
) # TODO: finite shots fails because Prod is not currently differentiable.
@pytest.mark.parametrize("shots", [None, 50000])
def test_qnode_torch(self, shots):
""" "Test that the QNode executes and is differentiable with Torch. The shots
argument controls whether autodiff or parameter-shift gradients are used."""
import torch

# TODO: finite shots fails because Prod is not currently differentiable.
if shots is not None:
pytest.xfail()

dev = qml.device("default.qubit", shots=shots, seed=10)
diff_method = "backprop" if shots is None else "parameter-shift"
qnode = qml.QNode(self.circuit, dev, interface="torch", diff_method=diff_method)
Expand Down
Loading