diff --git a/.github/workflows/check_in_artifact.yml b/.github/workflows/check_in_artifact.yml index de24731f239..8f84adc0414 100644 --- a/.github/workflows/check_in_artifact.yml +++ b/.github/workflows/check_in_artifact.yml @@ -56,29 +56,39 @@ on: required: false type: string default: '41898282+github-actions[bot]@users.noreply.github.com' + merge_multiple: + description: | + When multiple artifacts are matched, this changes the behavior of the destination directories. + If true, the downloaded artifacts will be in the same directory specified by path. + If false, the downloaded artifacts will be extracted into individual named directories within the specified path. + Optional. Default is 'true' + required: false + type: boolean + default: true jobs: check_in_artifact: runs-on: ubuntu-latest - + steps: - name: Checkout master uses: actions/checkout@v4 with: fetch-depth: ${{ inputs.master_branch_fetch_depth }} ref: master - + - name: Download artifacts uses: actions/download-artifact@v4 with: pattern: ${{ inputs.artifact_name_pattern }} path: ${{ inputs.artifact_save_path }} - + merge-multiple: ${{ inputs.merge_multiple }} + - name: Determine if changes have been made id: changed run: | echo "has_changes=$(git status --porcelain | wc -l | awk '{print $1}')" >> $GITHUB_OUTPUT - + - name: Prepare Commit Author if: steps.changed.outputs.has_changes != '0' env: @@ -88,7 +98,8 @@ jobs: run: | git config user.name "$COMMIT_AUTHOR_NAME" git config user.email "$COMMIT_AUTHOR_EMAIL" - + git stash push --all + echo "Checking if $HEAD_BRANCH_NAME exists..." if git ls-remote --exit-code origin "refs/heads/$HEAD_BRANCH_NAME"; then echo "$HEAD_BRANCH_NAME exists! Checking out..." @@ -97,17 +108,18 @@ jobs: echo "$HEAD_BRANCH_NAME does not exist! Creating..." git checkout -b "$HEAD_BRANCH_NAME" fi - + - name: Stage changes if: steps.changed.outputs.has_changes != '0' env: HEAD_BRANCH_NAME: ${{ inputs.pull_request_head_branch_name }} COMMIT_MESSAGE_DESCRIPTION: ${{ inputs.commit_message_description != '' && format('-> {0}', inputs.commit_message_description) || '' }} run: | + git checkout stash -- . git add ${{ inputs.artifact_save_path }} - git commit -m "Check in artifacts$COMMIT_MESSAGE_DESCRIPTION" + git commit --allow-empty -m "Check in artifacts$COMMIT_MESSAGE_DESCRIPTION" git push -f --set-upstream origin "$HEAD_BRANCH_NAME" - + # Create PR to master - name: Create Pull Request to master if: steps.changed.outputs.has_changes != '0' @@ -117,20 +129,21 @@ jobs: PR_TITLE: ${{ inputs.pull_request_title }} PR_BODY: ${{ inputs.pull_request_body }} run: | - EXISTING_CLOSED_PR="$(gh pr list --state closed --base master --head $HEAD_BRANCH_NAME --json url --jq '.[].url')" - - if [ -n "${EXISTING_CLOSED_PR}" ]; then - echo "Reopening PR... ${EXISTING_CLOSED_PR}" - gh pr reopen "${EXISTING_CLOSED_PR}" - exit 0 - fi - - EXISTING_PR="$(gh pr list --state open --base master --head $HEAD_BRANCH_NAME --json url --jq '.[].url')" - + EXISTING_PR="$(gh pr list --state open --base master --head $HEAD_BRANCH_NAME --json 'url' --jq '.[].url' | head -n 1)" + if [ -n "${EXISTING_PR}" ]; then echo "PR already exists ==> ${EXISTING_PR}" exit 0 else echo "Creating PR..." gh pr create --title "$PR_TITLE" --body "$PR_BODY" + exit 0 + fi + + EXISTING_CLOSED_PR="$(gh pr list --state closed --base master --head $HEAD_BRANCH_NAME --json 'mergedAt,url' --jq '.[] | select(.mergedAt == null).url' | head -n 1)" + + if [ -n "${EXISTING_CLOSED_PR}" ]; then + echo "Reopening PR... ${EXISTING_CLOSED_PR}" + gh pr reopen "${EXISTING_CLOSED_PR}" + exit 0 fi diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 71b47ff09a3..0711a0c292d 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -35,8 +35,6 @@ jobs: sphinx: if: github.event.pull_request.draft == false - env: - DEPS_BRANCH: bot/stable-deps-update needs: [determine_runner] runs-on: ${{ needs.determine_runner.outputs.runner_group }} steps: @@ -50,46 +48,32 @@ jobs: && pip3 install . && pip3 install openfermionpyscf && pip3 install aiohttp fsspec h5py - && pip freeze | grep -v 'file:///' > .github/stable/doc.txt.tmp build-command: "sphinx-build -b html . _build -W --keep-going" - - name: Prepare local repo - if: github.event.pull_request.head.repo.full_name == 'PennyLaneAI/pennylane' + - name: Freeze dependencies + shell: bash run: | - git fetch - git config user.name "GitHub Actions Bot" - git config user.email "<>" - if git ls-remote --exit-code origin "refs/heads/${{ env.DEPS_BRANCH }}"; then - git checkout "${{ env.DEPS_BRANCH }}" - else - git checkout master - git pull - git checkout -b "${{ env.DEPS_BRANCH }}" - fi - mv -f .github/stable/doc.txt.tmp .github/stable/doc.txt + pip freeze | grep -v 'file:///' > doc.txt + cat doc.txt - - name: Determine if changes have been made - if: github.event.pull_request.head.repo.full_name == 'PennyLaneAI/pennylane' - id: changed - run: | - echo "has_changes=$(git status --porcelain | wc -l | awk '{print $1}')" >> $GITHUB_OUTPUT - - - name: Stage changes - if: github.event.pull_request.head.repo.full_name == 'PennyLaneAI/pennylane' && steps.changed.outputs.has_changes != '0' - run: | - git add .github/stable/doc.txt - git commit -m "Update stable docs dependencies" - git push -f --set-upstream origin "${{ env.DEPS_BRANCH }}" - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Upload frozen requirements + uses: actions/upload-artifact@v4 + with: + name: frozen-doc + path: doc.txt - # Create PR to master - - name: Create pull request - if: github.event.pull_request.head.repo.full_name == 'PennyLaneAI/pennylane' && steps.changed.outputs.has_changes != '0' - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - function gh_pr_up() { - gh pr create $* || gh pr edit $* - } - gh_pr_up --title \"Update stable doc dependency files\" --body \".\" + upload-stable-deps: + if: github.event.pull_request.draft == false + needs: + - determine_runner + - sphinx + uses: ./.github/workflows/check_in_artifact.yml + with: + artifact_name_pattern: "frozen-doc" + artifact_save_path: ".github/stable/" + merge_multiple: true + pull_request_head_branch_name: bot/stable-deps-update + commit_message_description: Frozen Doc Dependencies Update + pull_request_title: Update stable dependency files + pull_request_body: | + Automatic update of stable requirement files to snapshot valid python environments. diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bca7aa27c59..8a607060930 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -39,7 +39,7 @@ jobs: upload-stable-deps: needs: tests uses: ./.github/workflows/check_in_artifact.yml - if: github.event_name == 'push' + if: github.event_name == 'schedule' with: artifact_name_pattern: "frozen-*" artifact_save_path: ".github/stable/" diff --git a/doc/_static/draw_mpl/per_wire_options.png b/doc/_static/draw_mpl/per_wire_options.png new file mode 100644 index 00000000000..d4baf342c97 Binary files /dev/null and b/doc/_static/draw_mpl/per_wire_options.png differ diff --git a/doc/_static/tape_mpl/per_wire_options.png b/doc/_static/tape_mpl/per_wire_options.png new file mode 100644 index 00000000000..529ae1f2aba Binary files /dev/null and b/doc/_static/tape_mpl/per_wire_options.png differ diff --git a/doc/code/qml_drawer.rst b/doc/code/qml_drawer.rst index 69a03567910..f4be0217146 100644 --- a/doc/code/qml_drawer.rst +++ b/doc/code/qml_drawer.rst @@ -85,4 +85,4 @@ Currently Available Styles +|pls|+|plw|+|skd|+ +-----+-----+-----+ +|sol|+|sod|+|def|+ -+-----+-----+-----+ \ No newline at end of file ++-----+-----+-----+ diff --git a/doc/conf.py b/doc/conf.py index 0e61021a768..99063fdf06d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -83,7 +83,7 @@ intersphinx_mapping = { "demo": ("https://pennylane.ai/qml/", None), - "catalyst": ("https://docs.pennylane.ai/projects/catalyst/en/stable", None) + "catalyst": ("https://docs.pennylane.ai/projects/catalyst/en/stable", None), } mathjax_path = ( @@ -113,6 +113,7 @@ # built documents. import pennylane + pennylane.Hamiltonian = pennylane.ops.Hamiltonian # The full version, including alpha/beta/rc tags. @@ -254,7 +255,6 @@ # Xanadu theme options (see theme.conf for more information). html_theme_options = { - "navbar_active_link": 4, "extra_copyrights": [ "TensorFlow, the TensorFlow logo, and any related marks are trademarks " "of Google Inc." ], diff --git a/doc/introduction/interfaces.rst b/doc/introduction/interfaces.rst index 7eb3d391600..61ff7aeb61f 100644 --- a/doc/introduction/interfaces.rst +++ b/doc/introduction/interfaces.rst @@ -34,6 +34,17 @@ a :class:`QNode `, e.g., If no interface is specified, PennyLane will automatically determine the interface based on provided arguments and keyword arguments. See ``qml.workflow.SUPPORTED_INTERFACES`` for a list of all accepted interface strings. +.. warning:: + + ``ComplexWarning`` messages may appear when running differentiable workflows involving both complex and float types, particularly + with certain interfaces. These warnings are common in backpropagation due to the nature of complex casting and do not + indicate an error in computation. If desired, you can suppress these warnings by adding the following code: + + .. code-block:: python + + import warnings + warnings.filterwarnings("ignore", category=np.ComplexWarning) + This will allow native numerical objects of the specified library (NumPy arrays, JAX arrays, Torch Tensors, or TensorFlow Tensors) to be passed as parameters to the quantum circuit. It also makes the gradients of the quantum circuit accessible to the classical library, enabling the diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1bddb979d80..79530d90a51 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -3,7 +3,7 @@ # Release 0.40.0-dev (development release)

New features since last release

- + * A `DeviceCapabilities` data class is defined to contain all capabilities of the device's execution interface (i.e. its implementation of `Device.execute`). A TOML file can be used to define the capabilities of a device, and it can be loaded into a `DeviceCapabilities` object. [(#6407)](https://github.com/PennyLaneAI/pennylane/pull/6407) @@ -15,21 +15,30 @@ True ``` +

New API for Qubit Mixed

+ +* Added `qml.devices.qubit_mixed` module for mixed-state qubit device support. This module introduces: + + [(#6379)](https://github.com/PennyLaneAI/pennylane/pull/6379) An `apply_operation` helper function featuring: + + * Two density matrix contraction methods using `einsum` and `tensordot` + + * Optimized handling of special cases including: Diagonal operators, Identity operators, CX (controlled-X), Multi-controlled X gates, Grover operators + + [(#6503)](https://github.com/PennyLaneAI/pennylane/pull/6503) A submodule 'initialize_state' featuring a `create_initial_state` function for initializing a density matrix from `qml.StatePrep` operations or `qml.QubitDensityMatrix` operations +

Improvements 🛠

-

Other Improvements

+* Added support for the `wire_options` dictionary to customize wire line formatting in `qml.draw_mpl` circuit + visualizations, allowing global and per-wire customization with options like `color`, `linestyle`, and `linewidth`. + [(#6486)](https://github.com/PennyLaneAI/pennylane/pull/6486) -* Added `qml.devices.qubit_mixed` module for mixed-state qubit device support. This module introduces: - - A new API for mixed-state operations - - An `apply_operation` helper function featuring: - - Two density matrix contraction methods using `einsum` and `tensordot` - - Optimized handling of special cases including: - - Diagonal operators - - Identity operators - - CX (controlled-X) - - Multi-controlled X gates - - Grover operators - [(#6379)](https://github.com/PennyLaneAI/pennylane/pull/6379) +

Capturing and representing hybrid programs

+ +* `jax.vmap` can be captured with `qml.capture.make_plxpr` and is compatible with quantum circuits. + [(#6349)](https://github.com/PennyLaneAI/pennylane/pull/6349) + +

Other Improvements

* `qml.BasisRotation` template is now JIT compatible. [(#6019)](https://github.com/PennyLaneAI/pennylane/pull/6019) @@ -43,11 +52,21 @@

Documentation 📝

+* Add a warning message to Gradients and training documentation about ComplexWarnings + [(#6543)](https://github.com/PennyLaneAI/pennylane/pull/6543) +

Bug fixes 🐛

+* Fixed `Identity.__repr__` to return correct wires list. + [(#6506)](https://github.com/PennyLaneAI/pennylane/pull/6506) +

Contributors ✍️

This release contains contributions from (in alphabetical order): +Shiwen An Astral Cai, -Andrija Paurevic +Yushao Chen, +Pietropaolo Frisoni, +Andrija Paurevic, +Justin Pickering diff --git a/pennylane/_version.py b/pennylane/_version.py index 8a038f25c6b..4012d0f128d 100644 --- a/pennylane/_version.py +++ b/pennylane/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.40.0-dev8" +__version__ = "0.40.0-dev9" diff --git a/pennylane/capture/capture_diff.py b/pennylane/capture/capture_diff.py index 829a9516af1..627ff0243ec 100644 --- a/pennylane/capture/capture_diff.py +++ b/pennylane/capture/capture_diff.py @@ -25,28 +25,31 @@ @lru_cache -def create_non_jvp_primitive(): - """Create a primitive type ``NonJVPPrimitive``, which binds to JAX's JVPTrace - like a standard Python function and otherwise behaves like jax.core.Primitive. +def create_non_interpreted_prim(): + """Create a primitive type ``NonInterpPrimitive``, which binds to JAX's JVPTrace + and BatchTrace objects like a standard Python function and otherwise behaves like jax.core.Primitive. """ if not has_jax: # pragma: no cover return None # pylint: disable=too-few-public-methods - class NonJVPPrimitive(jax.core.Primitive): + class NonInterpPrimitive(jax.core.Primitive): """A subclass to JAX's Primitive that works like a Python function - when evaluating JVPTracers.""" + when evaluating JVPTracers and BatchTracers.""" def bind_with_trace(self, trace, args, params): - """Bind the ``NonJVPPrimitive`` with a trace. If the trace is a ``JVPTrace``, - binding falls back to a standard Python function call. Otherwise, the - bind call of JAX's standard Primitive is used.""" - if isinstance(trace, jax.interpreters.ad.JVPTrace): + """Bind the ``NonInterpPrimitive`` with a trace. + + If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call. + Otherwise, the bind call of JAX's standard Primitive is used.""" + if isinstance( + trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace) + ): return self.impl(*args, **params) return super().bind_with_trace(trace, args, params) - return NonJVPPrimitive + return NonInterpPrimitive @lru_cache @@ -57,7 +60,7 @@ def _get_grad_prim(): if not has_jax: # pragma: no cover return None - grad_prim = create_non_jvp_primitive()("grad") + grad_prim = create_non_interpreted_prim()("grad") grad_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init # pylint: disable=too-many-arguments @@ -89,7 +92,7 @@ def _get_jacobian_prim(): """Create a primitive for Jacobian computations. This primitive is used when capturing ``qml.jacobian``. """ - jacobian_prim = create_non_jvp_primitive()("jacobian") + jacobian_prim = create_non_interpreted_prim()("jacobian") jacobian_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init # pylint: disable=too-many-arguments diff --git a/pennylane/capture/capture_operators.py b/pennylane/capture/capture_operators.py index 8bea0be3f31..006298f8f23 100644 --- a/pennylane/capture/capture_operators.py +++ b/pennylane/capture/capture_operators.py @@ -20,7 +20,7 @@ import pennylane as qml -from .capture_diff import create_non_jvp_primitive +from .capture_diff import create_non_interpreted_prim has_jax = True try: @@ -103,7 +103,7 @@ def create_operator_primitive( if not has_jax: return None - primitive = create_non_jvp_primitive()(operator_type.__name__) + primitive = create_non_interpreted_prim()(operator_type.__name__) @primitive.def_impl def _(*args, **kwargs): diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py index 491b9f3f6a4..de19f464701 100644 --- a/pennylane/capture/capture_qnode.py +++ b/pennylane/capture/capture_qnode.py @@ -14,25 +14,65 @@ """ This submodule defines a capture compatible call to QNodes. """ - from copy import copy from dataclasses import asdict from functools import lru_cache, partial +from numbers import Number +from warnings import warn import pennylane as qml +from pennylane.typing import TensorLike from .flatfn import FlatFn has_jax = True try: import jax - from jax.interpreters import ad + from jax.interpreters import ad, batching except ImportError: has_jax = False -def _get_shapes_for(*measurements, shots=None, num_device_wires=0): +def _is_scalar_tensor(arg) -> bool: + """Check if an argument is a scalar tensor-like object or a numeric scalar.""" + + if isinstance(arg, Number): + return True + + if isinstance(arg, TensorLike): + + if arg.size == 0: + raise ValueError("Empty tensors are not supported with jax.vmap.") + + if arg.shape == (): + return True + + if len(arg.shape) > 1: + raise ValueError( + "One argument has more than one dimension. " + "Currently, only single-dimension batching is supported." + ) + + return False + + +def _get_batch_shape(args, batch_dims): + """Calculate the batch shape for the given arguments and batch dimensions.""" + + if batch_dims is None: + return () + + input_shapes = [ + (arg.shape[batch_dim],) for arg, batch_dim in zip(args, batch_dims) if batch_dim is not None + ] + + return jax.lax.broadcast_shapes(*input_shapes) + + +def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=()): + """Calculate the abstract output shapes for the given measurements.""" + if jax.config.jax_enable_x64: # pylint: disable=no-member dtype_map = { float: jax.numpy.float64, @@ -53,7 +93,8 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0): for s in shots: for m in measurements: shape, dtype = m.aval.abstract_eval(shots=s, num_device_wires=num_device_wires) - shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype))) + shapes.append(jax.core.ShapedArray(batch_shape + shape, dtype_map.get(dtype, dtype))) + return shapes @@ -66,21 +107,88 @@ def _get_qnode_prim(): # pylint: disable=too-many-arguments @qnode_prim.def_impl - def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts): + def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None): consts = args[:n_consts] - args = args[n_consts:] + non_const_args = args[n_consts:] def qfunc(*inner_args): return jax.core.eval_jaxpr(qfunc_jaxpr, consts, *inner_args) qnode = qml.QNode(qfunc, device, **qnode_kwargs) - return qnode._impl_call(*args, shots=shots) # pylint: disable=protected-access + + if batch_dims is not None: + # pylint: disable=protected-access + return jax.vmap(partial(qnode._impl_call, shots=shots), batch_dims)(*non_const_args) + + # pylint: disable=protected-access + return qnode._impl_call(*non_const_args, shots=shots) # pylint: disable=unused-argument @qnode_prim.def_abstract_eval - def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts): + def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None): + mps = qfunc_jaxpr.outvars - return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires)) + + return _get_shapes_for( + *mps, + shots=shots, + num_device_wires=len(device.wires), + batch_shape=_get_batch_shape(args[n_consts:], batch_dims), + ) + + def _qnode_batching_rule( + batched_args, + batch_dims, + qnode, + shots, + device, + qnode_kwargs, + qfunc_jaxpr, + n_consts, + ): + """ + Batching rule for the ``qnode`` primitive. + + This rule exploits the parameter broadcasting feature of the QNode to vectorize the circuit execution. + """ + + for i, (arg, batch_dim) in enumerate(zip(batched_args, batch_dims)): + + if _is_scalar_tensor(arg): + continue + + # Regardless of their shape, jax.vmap treats constants as scalars + # by automatically inserting `None` as the batch dimension. + if i < n_consts: + raise ValueError( + f"Constant argument at index {i} is not scalar. ", + "Only scalar constants are currently supported with jax.vmap.", + ) + + # To resolve this, we need to add more properties to the AbstractOperator + # class to indicate which operators support batching and check them here + if arg.size > 1 and batch_dim is None: + warn( + f"Argument at index {i} has more than 1 element but is not batched. " + "This may lead to unintended behavior or wrong results if the argument is provided " + "using parameter broadcasting to a quantum operation that supports batching.", + UserWarning, + ) + + result = qnode_prim.bind( + *batched_args, + shots=shots, + qnode=qnode, + device=device, + qnode_kwargs=qnode_kwargs, + qfunc_jaxpr=qfunc_jaxpr, + n_consts=n_consts, + batch_dims=batch_dims[n_consts:], + ) + + # The batch dimension is at the front (axis 0) for all elements in the result. + # JAX doesn't expose `out_axes` in the batching rule. + return result, (0,) * len(result) def make_zero(tan, arg): return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan @@ -91,6 +199,8 @@ def _qnode_jvp(args, tangents, **impl_kwargs): ad.primitive_jvps[qnode_prim] = _qnode_jvp + batching.primitive_batchers[qnode_prim] = _qnode_batching_rule + return qnode_prim @@ -173,12 +283,14 @@ def f(x): flat_fn = FlatFn(qfunc) qfunc_jaxpr = jax.make_jaxpr(flat_fn)(*args) + execute_kwargs = copy(qnode.execute_kwargs) mcm_config = asdict(execute_kwargs.pop("mcm_config")) qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config} qnode_prim = _get_qnode_prim() flat_args = jax.tree_util.tree_leaves(args) + res = qnode_prim.bind( *qfunc_jaxpr.consts, *flat_args, diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index 1e0c3f10c24..dfced4def45 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -17,7 +17,7 @@ from collections.abc import Callable import pennylane as qml -from pennylane.capture.capture_diff import create_non_jvp_primitive +from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.capture.flatfn import FlatFn from .compiler import ( @@ -407,7 +407,7 @@ def _get_while_loop_qfunc_prim(): import jax # pylint: disable=import-outside-toplevel - while_loop_prim = create_non_jvp_primitive()("while_loop") + while_loop_prim = create_non_interpreted_prim()("while_loop") while_loop_prim.multiple_results = True @while_loop_prim.def_impl @@ -622,7 +622,7 @@ def _get_for_loop_qfunc_prim(): import jax # pylint: disable=import-outside-toplevel - for_loop_prim = create_non_jvp_primitive()("for_loop") + for_loop_prim = create_non_interpreted_prim()("for_loop") for_loop_prim.multiple_results = True @for_loop_prim.def_impl diff --git a/pennylane/devices/qubit_mixed/__init__.py b/pennylane/devices/qubit_mixed/__init__.py index 61cf1e84ad8..9cc48d1ffab 100644 --- a/pennylane/devices/qubit_mixed/__init__.py +++ b/pennylane/devices/qubit_mixed/__init__.py @@ -24,3 +24,4 @@ apply_operation """ from .apply_operation import apply_operation +from .initialize_state import create_initial_state diff --git a/pennylane/devices/qubit_mixed/initialize_state.py b/pennylane/devices/qubit_mixed/initialize_state.py new file mode 100644 index 00000000000..26055909e61 --- /dev/null +++ b/pennylane/devices/qubit_mixed/initialize_state.py @@ -0,0 +1,68 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functions to prepare a state.""" + +from collections.abc import Iterable + +import pennylane as qml +import pennylane.numpy as np +from pennylane import math + + +def create_initial_state( + # pylint: disable=unsupported-binary-operation + wires: qml.wires.Wires | Iterable, + prep_operation: qml.operation.StatePrepBase | qml.QubitDensityMatrix = None, + like: str = None, +): + r""" + Returns an initial state, defaulting to :math:`\ket{0}` if no state-prep operator is provided. + + Args: + wires (Union[Wires, Iterable]): The wires to be present in the initial state + prep_operation (Optional[StatePrepBase]): An operation to prepare the initial state + like (Optional[str]): The machine learning interface used to create the initial state. + Defaults to None + + Returns: + array: The initial density matrix (tensor form) of a circuit + """ + num_wires = len(wires) + num_axes = ( + 2 * num_wires + ) # we initialize the density matrix as the tensor form to keep compatibility with the rest of the module + if not prep_operation: + state = np.zeros((2,) * num_axes, dtype=complex) + state[(0,) * num_axes] = 1 + return math.asarray(state, like=like) + + if isinstance(prep_operation, qml.QubitDensityMatrix): + density_matrix = prep_operation.data + + else: + pure_state = prep_operation.state_vector(wire_order=list(wires)) + density_matrix = np.outer(pure_state, np.conj(pure_state)) + return _post_process(density_matrix, num_axes, like) + + +def _post_process(density_matrix, num_axes, like): + r""" + This post processor is necessary to ensure that the density matrix is in the correct format, i.e. the original tensor form, instead of the pure matrix form, as requested by all the other more fundamental chore functions in the module (again from some legacy code). + """ + density_matrix = np.reshape(density_matrix, (-1,) + (2,) * num_axes) + dtype = str(density_matrix.dtype) + floating_single = "float32" in dtype or "complex64" in dtype + dtype = "complex64" if floating_single else "complex128" + dtype = "complex128" if like == "tensorflow" else dtype + return math.cast(math.asarray(density_matrix, like=like), dtype) diff --git a/pennylane/drawer/draw.py b/pennylane/drawer/draw.py index df9e6e76829..2a3f1202560 100644 --- a/pennylane/drawer/draw.py +++ b/pennylane/drawer/draw.py @@ -385,7 +385,9 @@ def draw_mpl( fontsize (float or str): fontsize for text. Valid strings are ``{'xx-small', 'x-small', 'small', 'medium', large', 'x-large', 'xx-large'}``. Default is ``14``. - wire_options (dict): matplotlib formatting options for the wire lines + wire_options (dict): matplotlib formatting options for the wire lines. In addition to + standard options, options per wire can be specified with ``wire_label: options`` + pairs, also see examples below. label_options (dict): matplotlib formatting options for the wire labels show_wire_labels (bool): Whether or not to show the wire labels. active_wire_notches (bool): whether or not to add notches indicating active wires. @@ -458,7 +460,8 @@ def circuit2(x, y): **Wires:** - The keywords ``wire_order`` and ``show_all_wires`` control the location of wires from top to bottom. + The keywords ``wire_order`` and ``show_all_wires`` control the location of wires + from top to bottom. .. code-block:: python @@ -470,8 +473,8 @@ def circuit2(x, y): :width: 60% :target: javascript:void(0); - If a wire is in ``wire_order``, but not in the ``tape``, it will be omitted by default. Only by selecting - ``show_all_wires=True`` will empty wires be displayed. + If a wire is in ``wire_order``, but not in the ``tape``, it will be omitted by default. + Only by selecting ``show_all_wires=True`` will empty wires be displayed. .. code-block:: python @@ -568,6 +571,25 @@ def circuit2(x, y): :width: 60% :target: javascript:void(0); + + Additionally, ``wire_options`` may contain sub-dictionaries of matplotlib options assigned + to separate wire labels, which will control the line style for the respective individual wires. + + .. code-block:: python + + wire_options = { + 'color': 'teal', # all wires but wire 2 will be teal + 'linewidth': 5, # all wires but wire 2 will be bold + 2: {'color': 'orange', 'linestyle': '--'}, # wire 2 will be orange and dashed + } + fig, ax = qml.draw_mpl(circuit, wire_options=wire_options)(1.2345,1.2345) + fig.show() + + .. figure:: ../../_static/draw_mpl/per_wire_options.png + :align: center + :width: 60% + :target: javascript:void(0); + **Levels:** The ``level`` keyword argument allows one to select a subset of the transforms to apply on the ``QNode`` diff --git a/pennylane/drawer/mpldrawer.py b/pennylane/drawer/mpldrawer.py index eada58dcad9..64e92a50cfe 100644 --- a/pennylane/drawer/mpldrawer.py +++ b/pennylane/drawer/mpldrawer.py @@ -294,11 +294,25 @@ def __init__(self, n_layers, n_wires, c_wires=0, wire_options=None, figsize=None if wire_options is None: wire_options = {} - # adding wire lines - self._wire_lines = [ - plt.Line2D((-1, self.n_layers), (wire, wire), zorder=1, **wire_options) - for wire in range(self.n_wires) - ] + # Separate global options from per wire options + global_options = {k: v for k, v in wire_options.items() if not isinstance(v, dict)} + wire_specific_options = {k: v for k, v in wire_options.items() if isinstance(v, dict)} + + # Adding wire lines with individual styles based on wire_options + self._wire_lines = [] + for wire in range(self.n_wires): + specific_options = wire_specific_options.get(wire, {}) + line_options = {**global_options, **specific_options} + + # Create Line2D with the combined options + line = plt.Line2D( + (-1, self.n_layers), + (wire, wire), + zorder=1, + **line_options, + ) + self._wire_lines.append(line) + for line in self._wire_lines: self._ax.add_line(line) diff --git a/pennylane/drawer/tape_mpl.py b/pennylane/drawer/tape_mpl.py index 76c485cc67b..b1aa8f93af4 100644 --- a/pennylane/drawer/tape_mpl.py +++ b/pennylane/drawer/tape_mpl.py @@ -303,7 +303,9 @@ def tape_mpl( fontsize (float or str): fontsize for text. Valid strings are ``{'xx-small', 'x-small', 'small', 'medium', large', 'x-large', 'xx-large'}``. Default is ``14``. - wire_options (dict): matplotlib formatting options for the wire lines + wire_options (dict): matplotlib formatting options for the wire lines. In addition to + standard options, options per wire can be specified with ``wire_label: options`` + pairs, also see examples below. label_options (dict): matplotlib formatting options for the wire labels show_wire_labels (bool): Whether or not to show the wire labels. active_wire_notches (bool): whether or not to add notches indicating active wires. @@ -328,7 +330,7 @@ def tape_mpl( measurements = [qml.expval(qml.Z(0))] tape = qml.tape.QuantumTape(ops, measurements) - fig, ax = tape_mpl(tape) + fig, ax = qml.drawer.tape_mpl(tape) fig.show() .. figure:: ../../_static/tape_mpl/default.png @@ -350,7 +352,7 @@ def tape_mpl( measurements = [qml.expval(qml.Z(0))] tape2 = qml.tape.QuantumTape(ops, measurements) - fig, ax = tape_mpl(tape2, decimals=2) + fig, ax = qml.drawer.tape_mpl(tape2, decimals=2) .. figure:: ../../_static/tape_mpl/decimals.png :align: center @@ -363,7 +365,7 @@ def tape_mpl( .. code-block:: python - fig, ax = tape_mpl(tape, wire_order=[3,2,1,0]) + fig, ax = qml.drawer.tape_mpl(tape, wire_order=[3,2,1,0]) .. figure:: ../../_static/tape_mpl/wire_order.png :align: center @@ -375,7 +377,7 @@ def tape_mpl( .. code-block:: python - fig, ax = tape_mpl(tape, wire_order=["aux"], show_all_wires=True) + fig, ax = qml.drawer.tape_mpl(tape, wire_order=["aux"], show_all_wires=True) .. figure:: ../../_static/tape_mpl/show_all_wires.png :align: center @@ -389,7 +391,7 @@ def tape_mpl( .. code-block:: python - fig, ax = tape_mpl(tape) + fig, ax = qml.drawer.tape_mpl(tape) fig.suptitle("My Circuit", fontsize="xx-large") options = {'facecolor': "white", 'edgecolor': "#f57e7e", "linewidth": 6, "zorder": -1} @@ -413,7 +415,7 @@ def tape_mpl( .. code-block:: python - fig, ax = tape_mpl(tape, style='sketch') + fig, ax = qml.drawer.tape_mpl(tape, style='sketch') .. figure:: ../../_static/tape_mpl/sketch_style.png :align: center @@ -437,7 +439,7 @@ def tape_mpl( plt.rcParams['lines.linewidth'] = 5 plt.rcParams['figure.facecolor'] = 'ghostwhite' - fig, ax = tape_mpl(tape, style="rcParams") + fig, ax = qml.drawer.tape_mpl(tape, style="rcParams") .. figure:: ../../_static/tape_mpl/rcparams.png :align: center @@ -450,7 +452,7 @@ def tape_mpl( .. code-block:: python - fig, ax = tape_mpl(tape, wire_options={'color':'teal', 'linewidth': 5}, + fig, ax = qml.drawer.tape_mpl(tape, wire_options={'color':'teal', 'linewidth': 5}, label_options={'size': 20}) .. figure:: ../../_static/tape_mpl/wires_labels.png @@ -458,6 +460,22 @@ def tape_mpl( :width: 60% :target: javascript:void(0); + Additionally, ``wire_options`` may contain sub-dictionaries of matplotlib options assigned + to separate wire labels, which will control the line style for the respective individual wires. + + .. code-block:: python + + wire_options = { + 'color': 'teal', # all wires but wire 2 will be teal + 'linewidth': 5, # all wires but wire 2 will be bold + 2: {'color': 'orange', 'linestyle': '--'}, # wire 2 will be orange and dashed + } + fig, ax = qml.drawer.tape_mpl(tape, wire_options=wire_options) + + .. figure:: ../../_static/tape_mpl/per_wire_options.png + :align: center + :width: 60% + :target: javascript:void(0); """ restore_params = {} diff --git a/pennylane/math/utils.py b/pennylane/math/utils.py index 9d716062382..caa47338144 100644 --- a/pennylane/math/utils.py +++ b/pennylane/math/utils.py @@ -466,6 +466,9 @@ def requires_grad(tensor, interface=None): interface (str): The name of the interface. Will be determined automatically if not provided. + Returns: + bool: whether the tensor is trainable or not. + **Example** Calling this function on a PennyLane NumPy array: @@ -539,6 +542,9 @@ def in_backprop(tensor, interface=None): interface (str): The name of the interface. Will be determined automatically if not provided. + Returns: + bool: whether the tensor is in a backpropagation environment or not. + **Example** >>> x = tf.Variable([0.6, 0.1]) diff --git a/pennylane/operation.py b/pennylane/operation.py index b2073f7a025..f405ec259b8 100644 --- a/pennylane/operation.py +++ b/pennylane/operation.py @@ -1234,6 +1234,7 @@ def _check_batching(self): "Broadcasting was attempted but the broadcasted dimensions " f"do not match: {first_dims}." ) + self._batch_size = first_dims[0] def __repr__(self) -> str: diff --git a/pennylane/ops/identity.py b/pennylane/ops/identity.py index fd4a6d80e42..dc6a401af25 100644 --- a/pennylane/ops/identity.py +++ b/pennylane/ops/identity.py @@ -80,10 +80,12 @@ def __repr__(self): """String representation.""" if len(self.wires) == 0: return "I()" - wire = self.wires[0] - if isinstance(wire, str): - return f"I('{wire}')" - return f"I({wire})" + if len(self.wires) == 1: + wire = self.wires[0] + if isinstance(wire, str): + return f"I('{wire}')" + return f"I({wire})" + return f"I({self.wires.tolist()})" @property def name(self): diff --git a/pennylane/ops/op_math/adjoint.py b/pennylane/ops/op_math/adjoint.py index f0ebacfe5e2..36adf191849 100644 --- a/pennylane/ops/op_math/adjoint.py +++ b/pennylane/ops/op_math/adjoint.py @@ -18,7 +18,7 @@ from typing import Callable, overload import pennylane as qml -from pennylane.capture.capture_diff import create_non_jvp_primitive +from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.compiler import compiler from pennylane.math import conj, moveaxis, transpose from pennylane.operation import Observable, Operation, Operator @@ -193,7 +193,7 @@ def _get_adjoint_qfunc_prim(): # if capture is enabled, jax should be installed import jax # pylint: disable=import-outside-toplevel - adjoint_prim = create_non_jvp_primitive()("adjoint_transform") + adjoint_prim = create_non_interpreted_prim()("adjoint_transform") adjoint_prim.multiple_results = True @adjoint_prim.def_impl diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 4e5db25c154..1f682c9b290 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -20,7 +20,7 @@ import pennylane as qml from pennylane import QueuingManager -from pennylane.capture.capture_diff import create_non_jvp_primitive +from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.capture.flatfn import FlatFn from pennylane.compiler import compiler from pennylane.measurements import MeasurementValue @@ -690,7 +690,7 @@ def _get_cond_qfunc_prim(): import jax # pylint: disable=import-outside-toplevel - cond_prim = create_non_jvp_primitive()("cond") + cond_prim = create_non_interpreted_prim()("cond") cond_prim.multiple_results = True @cond_prim.def_impl diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py index 6ac000cefd5..5dfb11a05f0 100644 --- a/pennylane/ops/op_math/controlled.py +++ b/pennylane/ops/op_math/controlled.py @@ -28,7 +28,7 @@ import pennylane as qml from pennylane import math as qmlmath from pennylane import operation -from pennylane.capture.capture_diff import create_non_jvp_primitive +from pennylane.capture.capture_diff import create_non_interpreted_prim from pennylane.compiler import compiler from pennylane.operation import Operator from pennylane.wires import Wires @@ -235,7 +235,7 @@ def _get_ctrl_qfunc_prim(): # if capture is enabled, jax should be installed import jax # pylint: disable=import-outside-toplevel - ctrl_prim = create_non_jvp_primitive()("ctrl_transform") + ctrl_prim = create_non_interpreted_prim()("ctrl_transform") ctrl_prim.multiple_results = True @ctrl_prim.def_impl diff --git a/pennylane/templates/embeddings/amplitude.py b/pennylane/templates/embeddings/amplitude.py index 0b52123ae75..eef3aa462db 100644 --- a/pennylane/templates/embeddings/amplitude.py +++ b/pennylane/templates/embeddings/amplitude.py @@ -55,9 +55,9 @@ class AmplitudeEmbedding(StatePrep): @qml.qnode(dev) def circuit(f=None): qml.AmplitudeEmbedding(features=f, wires=range(2)) - return qml.expval(qml.Z(0)), qml.state() + return qml.state() - res, state = circuit(f=[1/2, 1/2, 1/2, 1/2]) + state = circuit(f=[1/2, 1/2, 1/2, 1/2]) The final state of the device is - up to a global phase - equivalent to the input passed to the circuit: @@ -79,9 +79,9 @@ def circuit(f=None): @qml.qnode(dev) def circuit(f=None): qml.AmplitudeEmbedding(features=f, wires=range(2), normalize=True) - return qml.expval(qml.Z(0)), qml.state() + return qml.state() - res, state = circuit(f=[15, 15, 15, 15]) + state = circuit(f=[15, 15, 15, 15]) >>> state tensor([0.5+0.j, 0.5+0.j, 0.5+0.j, 0.5+0.j], requires_grad=True) @@ -98,9 +98,9 @@ def circuit(f=None): @qml.qnode(dev) def circuit(f=None): qml.AmplitudeEmbedding(features=f, wires=range(2), pad_with=0.) - return qml.expval(qml.Z(0)), qml.state() + return qml.state() - res, state = circuit(f=[1/sqrt(2), 1/sqrt(2)]) + state = circuit(f=[1/sqrt(2), 1/sqrt(2)]) >>> state tensor([0.70710678+0.j, 0.70710678+0.j, 0. +0.j, 0. +0.j], requires_grad=True) diff --git a/pennylane/typing.py b/pennylane/typing.py index 16de82bb6bd..9094a418fe6 100644 --- a/pennylane/typing.py +++ b/pennylane/typing.py @@ -35,7 +35,7 @@ class TensorLikeMETA(type): def __instancecheck__(cls, other): """Dunder method used to check if an object is a `TensorLike` instance.""" return ( - isinstance(other, _TensorLike.__args__) # TODO: Remove __args__ when python>=3.10 + isinstance(other, _TensorLike) or _is_jax(other) or _is_torch(other) or _is_tensorflow(other) @@ -44,7 +44,7 @@ def __instancecheck__(cls, other): def __subclasscheck__(cls, other): """Dunder method that checks if a class is a subclass of ``TensorLike``.""" return ( - issubclass(other, _TensorLike.__args__) # TODO: Remove __args__ when python>=3.10 + issubclass(other, _TensorLike) or _is_jax(other, subclass=True) or _is_torch(other, subclass=True) or _is_tensorflow(other, subclass=True) diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 191df685388..3e625aecbc7 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -359,3 +359,293 @@ def circuit(x): xt = -0.6 jvp = jax.jvp(circuit, (x,), (xt,)) assert qml.math.allclose(jvp, (qml.math.cos(x), -qml.math.sin(x) * xt)) + + +class TestQNodeVmapIntegration: + """Tests for integrating JAX vmap with the QNode primitive.""" + + @pytest.mark.parametrize( + "input, expected_shape", + [ + (jax.numpy.array([0.1]), (1,)), + (jax.numpy.array([0.1, 0.2]), (2,)), + (jax.numpy.array([0.1, 0.2, 0.3]), (3,)), + ], + ) + def test_qnode_vmap(self, input, expected_shape): + """Test that JAX can vmap over the QNode primitive via a registered batching rule.""" + + @qml.qnode(qml.device("default.qubit", wires=1)) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(jax.vmap(circuit))(input) + eqn0 = jaxpr.eqns[0] + + assert len(eqn0.outvars) == 1 + assert eqn0.outvars[0].aval.shape == expected_shape + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, input) + assert qml.math.allclose(res, jax.numpy.cos(input)) + + @pytest.mark.parametrize("x64_mode", (True, False)) + def test_qnode_vmap_x64_mode(self, x64_mode): + """Test that JAX can vmap over the QNode primitive with x64 mode enabled/disabled.""" + + initial_mode = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", x64_mode) + dtype = jax.numpy.float64 if x64_mode else jax.numpy.float32 + + @qml.qnode(qml.device("default.qubit", wires=1)) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + x = jax.numpy.array([0.1, 0.2, 0.3], dtype=dtype) + + jaxpr = jax.make_jaxpr(jax.vmap(circuit))(x) + eqn0 = jaxpr.eqns[0] + + assert len(eqn0.outvars) == 1 + assert eqn0.outvars[0].aval == jax.core.ShapedArray((3,), dtype) + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + assert qml.math.allclose(res, jax.numpy.cos(x)) + + jax.config.update("jax_enable_x64", initial_mode) + + def test_vmap_mixed_arguments(self): + """Test vmap with a mix of batched and non-batched arguments.""" + + @qml.qnode(qml.device("default.qubit", wires=2)) + def circuit(arr1, scalar1, arr2, scalar2): + qml.RX(arr1, 0) + qml.RY(scalar1, 0) + qml.RY(arr2, 1) + qml.RZ(scalar2, 1) + return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1)) + + arr1 = jax.numpy.array([0.1, 0.2, 0.3]) + arr2 = jax.numpy.array([0.2, 0.4, 0.6]) + scalar1 = 1.0 + scalar2 = 2.0 + + jaxpr = jax.make_jaxpr(jax.vmap(circuit, in_axes=(0, None, 0, None)))( + arr1, scalar1, arr2, scalar2 + ) + + assert len(jaxpr.out_avals) == 2 + assert jaxpr.out_avals[0].shape == (3,) + assert jaxpr.out_avals[1].shape == (3,) + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, arr1, scalar1, arr2, scalar2) + assert qml.math.allclose(res, circuit(arr1, scalar1, arr2, scalar2)) + # compare with jax.vmap to cover all code paths + assert qml.math.allclose( + res, jax.vmap(circuit, in_axes=(0, None, 0, None))(arr1, scalar1, arr2, scalar2) + ) + + def test_vmap_multiple_measurements(self): + """Test that JAX can vmap over the QNode primitive with multiple measurements.""" + + @qml.qnode(qml.device("default.qubit", wires=4, shots=5)) + def circuit(x): + qml.DoubleExcitation(x, wires=[0, 1, 2, 3]) + return qml.sample(), qml.probs(wires=(0, 1, 2)), qml.expval(qml.Z(0)) + + x = jax.numpy.array([1.0, 2.0]) + jaxpr = jax.make_jaxpr(jax.vmap(circuit))(x) + + res1_vmap, res2_vmap, res3_vmap = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + + assert len(jaxpr.eqns[0].outvars) == 3 + assert jaxpr.out_avals[0].shape == (2, 5, 4) + assert jaxpr.out_avals[1].shape == (2, 8) + assert jaxpr.out_avals[2].shape == (2,) + + assert qml.math.allclose(res1_vmap, jax.numpy.zeros((2, 5, 4))) + assert qml.math.allclose( + res2_vmap, jax.numpy.array([[1, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0]]) + ) + assert qml.math.allclose(res3_vmap, jax.numpy.array([1.0, 1.0])) + + def test_qnode_vmap_closure(self): + """Test that JAX can vmap over the QNode primitive with closure variables.""" + + const = jax.numpy.array(2.0) + + @qml.qnode(qml.device("default.qubit", wires=2)) + def circuit(x): + qml.RX(x, 0) + qml.RY(const, 1) + return qml.probs(wires=[0, 1]) + + x = jax.numpy.array([1.0, 2.0, 3.0]) + jaxpr = jax.make_jaxpr(jax.vmap(circuit))(x) + eqn0 = jaxpr.eqns[0] + + assert len(eqn0.invars) == 2 # one closure variable, one (batched) arg + assert eqn0.invars[0].aval.shape == () + assert eqn0.invars[1].aval.shape == (3,) + + assert len(eqn0.outvars) == 1 + assert eqn0.outvars[0].aval.shape == (3, 4) + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + assert qml.math.allclose(res, circuit(x)) + + def test_qnode_vmap_closure_error(self): + """Test that an error is raised when trying to vmap over a batched non-scalar closure variable.""" + dev = qml.device("default.qubit", wires=2) + + const = jax.numpy.array([2.0, 6.6]) + + @qml.qnode(dev) + def circuit(x): + qml.RY(x, 0) + qml.RX(const, wires=0) + return qml.expval(qml.PauliZ(0)) + + with pytest.raises( + ValueError, match="Only scalar constants are currently supported with jax.vmap." + ): + jax.make_jaxpr(jax.vmap(circuit))(jax.numpy.array([0.1, 0.2])) + + def test_vmap_overriding_shots(self): + """Test that the number of shots can be overridden on call with vmap.""" + + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev) + # pylint: disable=unused-argument + def circuit(x): + return qml.sample() + + x = jax.numpy.array([1.0, 2.0, 3.0]) + + jaxpr = jax.make_jaxpr(jax.vmap(partial(circuit, shots=50), in_axes=0))(x) + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + + assert len(jaxpr.eqns) == 1 + eqn0 = jaxpr.eqns[0] + + assert eqn0.primitive == qnode_prim + assert eqn0.params["device"] == dev + assert eqn0.params["shots"] == qml.measurements.Shots(50) + assert ( + eqn0.params["qfunc_jaxpr"].eqns[0].primitive + == qml.measurements.SampleMP._wires_primitive + ) + + assert eqn0.outvars[0].aval.shape == (3, 50) + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + assert qml.math.allclose(res, jax.numpy.zeros((3, 50))) + + def test_vmap_error_indexing(self): + """Test that an IndexError is raised when indexing a batched parameter.""" + + @qml.qnode(qml.device("default.qubit", wires=2)) + def circuit(vec, scalar): + qml.RX(vec[0], 0) + qml.RY(scalar, 1) + return qml.expval(qml.Z(0)) + + with pytest.raises(IndexError): + jax.make_jaxpr(jax.vmap(circuit, in_axes=(0, None)))( + jax.numpy.array([1.0, 2.0, 3.0]), 5.0 + ) + + def test_vmap_error_empty_array(self): + """Test that an error is raised when passing an empty array to vmap.""" + + @qml.qnode(qml.device("default.qubit", wires=1)) + def circuit(x): + qml.RX(x, wires=0) + return qml.expval(qml.Z(0)) + + with pytest.raises(ValueError, match="Empty tensors are not supported with jax.vmap."): + jax.make_jaxpr(jax.vmap(circuit))(jax.numpy.array([])) + + def test_warning_bypass_vmap(self): + """Test that a warning is raised when bypassing vmap.""" + dev = qml.device("default.qubit", wires=4) + + @qml.qnode(dev) + def circuit(param_array, param_array_2): + qml.RX(param_array, wires=2) + qml.DoubleExcitation(param_array_2[0], wires=[0, 1, 2, 3]) + return qml.expval(qml.PauliZ(0)) + + param_array = jax.numpy.array([1.0, 1.2, 1.3]) + param_array_2 = jax.numpy.array([2.0, 2.1, 2.2]) + + with pytest.warns(UserWarning, match="Argument at index 1 has more"): + jax.make_jaxpr(jax.vmap(circuit, in_axes=(0, None)))(param_array, param_array_2) + + def test_qnode_pytree_input_vmap(self): + """Test that we can capture and execute a qnode with a pytree input and vmap.""" + + @qml.qnode(qml.device("default.qubit", wires=2)) + def circuit(x): + qml.RX(x["val"], wires=x["wires"]) + return qml.expval(qml.Z(wires=x["wires"])) + + x = {"val": jax.numpy.array([0.1, 0.2]), "wires": 0} + jaxpr = jax.make_jaxpr(jax.vmap(circuit, in_axes=({"val": 0, "wires": None},)))(x) + + assert len(jaxpr.eqns[0].invars) == 2 + + assert len(jaxpr.eqns[0].outvars) == 1 + assert jaxpr.eqns[0].outvars[0].aval.shape == (2,) + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x["val"], x["wires"]) + assert qml.math.allclose(res, jax.numpy.cos(x["val"])) + + def test_qnode_deep_pytree_input_vmap(self): + """Test vmap over qnodes with deep pytree inputs.""" + + @qml.qnode(qml.device("default.qubit", wires=2)) + def circuit(x): + qml.RX(x["data"]["val"], wires=x["data"]["wires"]) + return qml.expval(qml.Z(wires=x["data"]["wires"])) + + x = {"data": {"val": jax.numpy.array([0.1, 0.2]), "wires": 0}} + jaxpr = jax.make_jaxpr(jax.vmap(circuit, in_axes=({"data": {"val": 0, "wires": None}},)))(x) + + assert len(jaxpr.eqns[0].invars) == 2 + + assert len(jaxpr.eqns[0].outvars) == 1 + assert jaxpr.eqns[0].outvars[0].aval.shape == (2,) + + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x["data"]["val"], x["data"]["wires"]) + assert qml.math.allclose(res, jax.numpy.cos(x["data"]["val"])) + + def test_qnode_pytree_output_vmap(self): + """Test that we can capture and execute a qnode with a pytree output and vmap.""" + + @qml.qnode(qml.device("default.qubit", wires=2)) + def circuit(x): + qml.RX(x, 0) + return {"a": qml.expval(qml.Z(0)), "b": qml.expval(qml.Y(0))} + + x = jax.numpy.array([1.2, 1.3]) + out = jax.vmap(circuit)(x) + + assert qml.math.allclose(out["a"], jax.numpy.cos(x)) + assert qml.math.allclose(out["b"], -jax.numpy.sin(x)) + assert list(out.keys()) == ["a", "b"] + + def test_error_multidimensional_batching(self): + """Test that an error is raised when trying to vmap over a multidimensional batched parameter.""" + + @qml.qnode(qml.device("default.qubit", wires=2)) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + with pytest.raises( + ValueError, match="Currently, only single-dimension batching is supported" + ): + jax.make_jaxpr(jax.vmap(circuit))(jax.numpy.array([[0.1, 0.2], [0.3, 0.4]])) diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index 0ac9bc2cffd..f41bbdafe12 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -1844,6 +1844,9 @@ def test_postselection_valid_finite_shots(self, param, mp, shots, interface, use if use_jit and (interface != "jax" or isinstance(shots, tuple)): pytest.skip("Cannot JIT in non-JAX interfaces, or with shot vectors.") + if isinstance(mp, qml.measurements.ClassicalShadowMP): + mp.seed = seed + dev = qml.device("default.qubit", seed=seed) param = qml.math.asarray(param, like=interface) diff --git a/tests/devices/qubit_mixed/test_qubit_mixed_initialize_state.py b/tests/devices/qubit_mixed/test_qubit_mixed_initialize_state.py new file mode 100644 index 00000000000..2510a038b0a --- /dev/null +++ b/tests/devices/qubit_mixed/test_qubit_mixed_initialize_state.py @@ -0,0 +1,110 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for initialize_state in devices/qubit_mixed/initialize_state.""" + +import pytest + +import pennylane as qml +from pennylane import StatePrep, math +from pennylane import numpy as np +from pennylane.devices.qubit_mixed import create_initial_state +from pennylane.operation import StatePrepBase + +ml_interfaces = ["numpy", "autograd", "jax", "torch", "tensorflow"] + + +def allzero_vec(num_wires, interface="numpy"): + """Returns the state vector of the all-zero state.""" + state = np.zeros(2**num_wires, dtype=complex) + state[0] = 1 + state = math.asarray(state, like=interface) + return state + + +def allzero_dm(num_wires, interface="numpy"): + """Returns the density matrix of the all-zero state.""" + num_axes = 2 * num_wires + dm = np.zeros((2,) * num_axes, dtype=complex) + dm[(0,) * num_axes] = 1 + dm = math.asarray(dm, like=interface) + return dm + + +@pytest.mark.all_interfaces +@pytest.mark.parametrize("interface", ml_interfaces) +class TestInitializeState: + """Test the functions in initialize_state.py""" + + # pylint:disable=unused-argument,too-few-public-methods + class DefaultPrep(StatePrepBase): + """A dummy class that assumes it was given a state vector.""" + + num_wires = qml.operation.AllWires + + def __init__(self, *args, **kwargs): + self.dtype = kwargs.pop("dtype", None) + super().__init__(*args, **kwargs) + + def state_vector(self, wire_order=None): + sv = self.parameters[0] + if self.dtype is not None: + sv = qml.math.cast(sv, self.dtype) + return sv + + def test_create_initial_state_no_state_prep(self, interface): + """Tests that create_initial_state works without a state-prep operation.""" + wires = [0, 1] + num_wires = len(wires) + state = create_initial_state(wires, like=interface) + + state_correct = allzero_dm(num_wires, interface) + assert math.allequal(state, state_correct) + assert math.get_interface(state) == interface + assert "complex" in str(state.dtype) + + def test_create_initial_state_with_dummy_state_prep(self, interface): + """Tests that create_initial_state works with a state-prep operation.""" + wires = [0, 1] + num_wires = len(wires) + + vec_correct = allzero_vec(num_wires, interface) + state_correct = allzero_dm(num_wires, interface) + prep_op = self.DefaultPrep(qml.math.array(vec_correct, like=interface), wires=wires) + state = create_initial_state(wires, prep_operation=prep_op, like=interface) + assert math.allequal(state, state_correct) + assert math.get_interface(state) == interface + + def test_create_initial_state_with_StatePrep(self, interface): + """Tests that create_initial_state works with a state-prep operation.""" + wires = [0, 1] + num_wires = len(wires) + # The following 2 lines are for reusing the statevec code on the density matrices + vec_correct = allzero_vec(num_wires, interface) + state_correct = allzero_dm(num_wires, interface) + state_correct_flatten = math.reshape(vec_correct, [-1]) + prep_op = StatePrep(qml.math.array(state_correct_flatten, like=interface), wires=wires) + state = create_initial_state(wires, prep_operation=prep_op, like=interface) + assert math.allequal(state, state_correct) + assert math.get_interface(state) == interface + + def test_create_initial_state_with_QubitDensityMatrix(self, interface): + """Tests that create_initial_state works with a state-prep operation.""" + wires = [0, 1] + num_wires = len(wires) + # The following 2 lines are for reusing the statevec code on the density matrices + state_correct = allzero_dm(num_wires, interface) + prep_op = qml.QubitDensityMatrix(qml.math.array(state_correct, like=interface), wires=wires) + state = create_initial_state(wires, prep_operation=prep_op, like=interface) + assert math.allequal(state, state_correct) + assert math.get_interface(state) == interface diff --git a/tests/devices/test_default_clifford.py b/tests/devices/test_default_clifford.py index fac9de067cf..ba871dc961b 100644 --- a/tests/devices/test_default_clifford.py +++ b/tests/devices/test_default_clifford.py @@ -218,10 +218,10 @@ def circuit_fn(): qml.sum(qml.PauliZ(0), qml.s_prod(2.0, qml.PauliY(1))), ], ) -def test_meas_var(shots, ops): +def test_meas_var(shots, ops, seed): """Test that variance measurements with `default.clifford` is possible and agrees with `default.qubit`.""" - dev_c = qml.device("default.clifford", shots=shots) + dev_c = qml.device("default.clifford", shots=shots, seed=seed) dev_q = qml.device("default.qubit") def circuit_fn(): diff --git a/tests/drawer/test_draw_mpl.py b/tests/drawer/test_draw_mpl.py index f2e7173298b..23ca78d19a2 100644 --- a/tests/drawer/test_draw_mpl.py +++ b/tests/drawer/test_draw_mpl.py @@ -328,6 +328,75 @@ def test_wire_options(self): assert w.get_color() == "black" assert w.get_linewidth() == 4 + @qml.qnode(dev) + def f_circ(x): + """Circuit on ten qubits.""" + qml.RX(x, wires=0) + for w in range(10): + qml.Hadamard(w) + return qml.expval(qml.PauliZ(0) @ qml.PauliY(1)) + + # All wires are orange + wire_options = {"color": "orange"} + _, ax = qml.draw_mpl(f_circ, wire_options=wire_options)(0.52) + + for w in ax.lines: + assert w.get_color() == "orange" + + # Wires are orange and cyan + wire_options = {0: {"color": "orange"}, 1: {"color": "cyan"}} + _, ax = qml.draw_mpl(f_circ, wire_options=wire_options)(0.52) + + assert ax.lines[0].get_color() == "orange" + assert ax.lines[1].get_color() == "cyan" + assert ax.lines[2].get_color() == "black" + + # Make all wires cyan and bold, + # except for wires 2 and 6, which are dashed and another color + wire_options = { + "color": "cyan", + "linewidth": 5, + 2: {"linestyle": "--", "color": "red"}, + 6: {"linestyle": "--", "color": "orange", "linewidth": 1}, + } + _, ax = qml.draw_mpl(f_circ, wire_options=wire_options)(0.52) + + for i, w in enumerate(ax.lines): + if i == 2: + assert w.get_color() == "red" + assert w.get_linestyle() == "--" + assert w.get_linewidth() == 5 + elif i == 6: + assert w.get_color() == "orange" + assert w.get_linestyle() == "--" + assert w.get_linewidth() == 1 + else: + assert w.get_color() == "cyan" + assert w.get_linestyle() == "-" + assert w.get_linewidth() == 5 + + wire_options = { + "linewidth": 5, + 2: {"linestyle": "--", "color": "red"}, + 6: {"linestyle": "--", "color": "orange"}, + } + + _, ax = qml.draw_mpl(f_circ, wire_options=wire_options)(0.52) + + for i, w in enumerate(ax.lines): + if i == 2: + assert w.get_color() == "red" + assert w.get_linestyle() == "--" + assert w.get_linewidth() == 5 + elif i == 6: + assert w.get_color() == "orange" + assert w.get_linestyle() == "--" + assert w.get_linewidth() == 5 + else: + assert w.get_color() == "black" + assert w.get_linestyle() == "-" + assert w.get_linewidth() == 5 + plt.close() diff --git a/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py b/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py index 3062841ba14..b1f6e86b9b7 100644 --- a/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py +++ b/tests/gradients/parameter_shift/test_parameter_shift_shot_vec.py @@ -507,13 +507,13 @@ class TestParameterShiftRule: @pytest.mark.parametrize("theta", angles) @pytest.mark.parametrize("shift", [np.pi / 2, 0.3]) @pytest.mark.parametrize("G", [qml.RX, qml.RY, qml.RZ, qml.PhaseShift]) - def test_pauli_rotation_gradient(self, mocker, G, theta, shift, broadcast): + def test_pauli_rotation_gradient(self, mocker, G, theta, shift, broadcast, seed): """Tests that the automatic gradients of Pauli rotations are correct.""" # pylint: disable=too-many-arguments spy = mocker.spy(qml.gradients.parameter_shift, "_get_operation_recipe") shot_vec = many_shots_shot_vector - dev = qml.device("default.qubit", wires=1, shots=shot_vec) + dev = qml.device("default.qubit", wires=1, shots=shot_vec, seed=seed) with qml.queuing.AnnotatedQueue() as q: qml.StatePrep(np.array([1.0, -1.0], requires_grad=False) / np.sqrt(2), wires=0) diff --git a/tests/ops/test_identity.py b/tests/ops/test_identity.py index 638bdc51de8..74892197cbd 100644 --- a/tests/ops/test_identity.py +++ b/tests/ops/test_identity.py @@ -18,8 +18,12 @@ import pennylane as qml from pennylane import Identity +op_wires = [[], [0], ["a"], [0, 1], ["a", "b", "c"], [100, "xasd", 12]] +op_repr = ["I()", "I(0)", "I('a')", "I([0, 1])", "I(['a', 'b', 'c'])", "I([100, 'xasd', 12])"] +op_params = tuple(zip(op_wires, op_repr)) -@pytest.mark.parametrize("wires", [[0], [0, 1], ["a", "b", "c"], [100, "xasd", 12]]) + +@pytest.mark.parametrize("wires", op_wires) class TestIdentity: # pylint: disable=protected-access def test_flatten_unflatten(self, wires): @@ -84,3 +88,10 @@ def test_matrix_representation(self, wires, tol): expected = np.eye(int(2 ** len(wires))) assert np.allclose(res_static, expected, atol=tol) assert np.allclose(res_dynamic, expected, atol=tol) + + +@pytest.mark.parametrize("wires, expected_repr", op_params) +def test_repr(wires, expected_repr): + """Test the operator's repr""" + op = Identity(wires=wires) + assert repr(op) == expected_repr diff --git a/tests/optimize/test_optimize_shot_adaptive.py b/tests/optimize/test_optimize_shot_adaptive.py index e1c4c3256ed..ed497bb1cd4 100644 --- a/tests/optimize/test_optimize_shot_adaptive.py +++ b/tests/optimize/test_optimize_shot_adaptive.py @@ -14,7 +14,6 @@ """Tests for the shot adaptive optimizer""" # pylint: disable=unused-argument import pytest -from flaky import flaky import pennylane as qml from pennylane import numpy as np @@ -101,16 +100,12 @@ class TestSingleShotGradientIntegration: """Integration tests to ensure that the single shot gradient is correctly computed for a variety of argument types.""" - dev = qml.device("default.qubit", wires=1, shots=100) - @staticmethod - @qml.qnode(dev) def cost_fn0(x): qml.RX(x, wires=0) return qml.expval(qml.PauliZ(0)) - @flaky(max_runs=3) - def test_single_argument_step(self, mocker, monkeypatch): + def test_single_argument_step(self, mocker, monkeypatch, seed): """Test that a simple QNode with a single argument correctly performs an optimization step, and that the single-shot gradients generated have the correct shape""" # pylint: disable=protected-access @@ -119,15 +114,17 @@ def test_single_argument_step(self, mocker, monkeypatch): spy_single_shot_qnodes = mocker.spy(opt, "_single_shot_qnode_gradients") spy_grad = mocker.spy(opt, "compute_grad") + dev = qml.device("default.qubit", wires=1, shots=100, seed=seed) x_init = np.array(0.5, requires_grad=True) - new_x = opt.step(self.cost_fn0, x_init) + qnode = qml.QNode(self.cost_fn0, device=dev) + new_x = opt.step(qnode, x_init) assert isinstance(new_x, np.tensor) assert new_x != x_init spy_grad.assert_called_once() spy_single_shot_qnodes.assert_called_once() - single_shot_grads = opt._single_shot_qnode_gradients(self.cost_fn0, [x_init], {}) + single_shot_grads = opt._single_shot_qnode_gradients(qnode, [x_init], {}) # assert single shot gradients are computed correctly assert len(single_shot_grads) == 1 @@ -143,7 +140,7 @@ def test_single_argument_step(self, mocker, monkeypatch): opt.s = [np.array(10)] # check that the gradient and variance are computed correctly - grad, grad_variance = opt.compute_grad(self.cost_fn0, [x_init], {}) + grad, grad_variance = opt.compute_grad(qnode, [x_init], {}) assert len(grad) == 1 assert len(grad_variance) == 1 assert np.allclose(grad, np.mean(single_shot_grads)) @@ -152,19 +149,18 @@ def test_single_argument_step(self, mocker, monkeypatch): # check that the gradient and variance are computed correctly # with a different shot budget opt.s = [np.array(5)] - grad, grad_variance = opt.compute_grad(self.cost_fn0, [x_init], {}) + grad, grad_variance = opt.compute_grad(qnode, [x_init], {}) assert len(grad) == 1 assert len(grad_variance) == 1 assert np.allclose(grad, np.mean(single_shot_grads[0][:5])) assert np.allclose(grad_variance, np.var(single_shot_grads[0][:5], ddof=1)) @staticmethod - @qml.qnode(dev) def cost_fn1(params): ansatz1(params) return qml.expval(qml.PauliZ(0)) - def test_single_array_argument_step(self, mocker, monkeypatch): + def test_single_array_argument_step(self, mocker, monkeypatch, seed): """Test that a simple QNode with a single array argument correctly performs an optimization step, and that the single-shot gradients generated have the correct shape""" # pylint: disable=protected-access @@ -172,14 +168,17 @@ def test_single_array_argument_step(self, mocker, monkeypatch): spy_single_shot_qnodes = mocker.spy(opt, "_single_shot_qnode_gradients") spy_grad = mocker.spy(opt, "compute_grad") + dev = qml.device("default.qubit", wires=1, shots=100, seed=seed) + x_init = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - new_x = opt.step(self.cost_fn1, x_init) + qnode = qml.QNode(self.cost_fn1, device=dev) + new_x = opt.step(qnode, x_init) assert isinstance(new_x, np.ndarray) assert not np.allclose(new_x, x_init) spy_single_shot_qnodes.assert_called_once() - single_shot_grads = opt._single_shot_qnode_gradients(self.cost_fn1, [x_init], {}) + single_shot_grads = opt._single_shot_qnode_gradients(qnode, [x_init], {}) spy_grad.assert_called_once() # assert single shot gradients are computed correctly @@ -196,7 +195,7 @@ def test_single_array_argument_step(self, mocker, monkeypatch): opt.s = [10 * np.ones([2, 3], dtype=np.int64)] # check that the gradient and variance are computed correctly - grad, grad_variance = opt.compute_grad(self.cost_fn1, [x_init], {}) + grad, grad_variance = opt.compute_grad(qnode, [x_init], {}) assert len(grad) == 1 assert len(grad_variance) == 1 assert grad[0].shape == x_init.shape @@ -210,7 +209,7 @@ def test_single_array_argument_step(self, mocker, monkeypatch): opt.s[0] = opt.s[0] // 2 # all array elements have a shot budget of 5 opt.s[0][0, 0] = 8 # set the shot budget of the zeroth element to 8 - grad, grad_variance = opt.compute_grad(self.cost_fn1, [x_init], {}) + grad, grad_variance = opt.compute_grad(qnode, [x_init], {}) assert len(grad) == 1 assert len(grad_variance) == 1 @@ -222,15 +221,12 @@ def test_single_array_argument_step(self, mocker, monkeypatch): assert np.allclose(grad[0][0, 1], np.mean(single_shot_grads[0][:5, 0, 1])) assert np.allclose(grad_variance[0][0, 1], np.var(single_shot_grads[0][:5, 0, 1], ddof=1)) - dev = qml.device("default.qubit", wires=2, shots=100) - @staticmethod - @qml.qnode(dev) def cost_fn2(params): ansatz2(params) return qml.expval(qml.PauliZ(0)) - def test_padded_single_array_argument_step(self, mocker, monkeypatch): + def test_padded_single_array_argument_step(self, mocker, monkeypatch, seed): """Test that a simple QNode with a single array argument with extra dimensions correctly performs an optimization step, and that the single-shot gradients generated have the correct shape""" @@ -241,13 +237,15 @@ def test_padded_single_array_argument_step(self, mocker, monkeypatch): shape = qml.StronglyEntanglingLayers.shape(n_layers=1, n_wires=2) x_init = np.ones(shape) * 0.5 - new_x = opt.step(self.cost_fn2, x_init) + dev = qml.device("default.qubit", wires=2, shots=100, seed=seed) + qnode = qml.QNode(self.cost_fn2, device=dev) + new_x = opt.step(qnode, x_init) assert isinstance(new_x, np.ndarray) assert not np.allclose(new_x, x_init) spy_single_shot_qnodes.assert_called_once() - single_shot_grads = opt._single_shot_qnode_gradients(self.cost_fn2, [x_init], {}) + single_shot_grads = opt._single_shot_qnode_gradients(qnode, [x_init], {}) spy_grad.assert_called_once() # assert single shot gradients are computed correctly @@ -264,7 +262,7 @@ def test_padded_single_array_argument_step(self, mocker, monkeypatch): opt.s = [10 * np.ones(shape, dtype=np.int64)] # check that the gradient and variance are computed correctly - grad, grad_variance = opt.compute_grad(self.cost_fn2, [x_init], {}) + grad, grad_variance = opt.compute_grad(qnode, [x_init], {}) assert len(grad) == 1 assert len(grad_variance) == 1 assert grad[0].shape == x_init.shape @@ -278,7 +276,7 @@ def test_padded_single_array_argument_step(self, mocker, monkeypatch): opt.s[0] = opt.s[0] // 2 # all array elements have a shot budget of 5 opt.s[0][0, 0, 0] = 8 # set the shot budget of the zeroth element to 8 - grad, grad_variance = opt.compute_grad(self.cost_fn2, [x_init], {}) + grad, grad_variance = opt.compute_grad(qnode, [x_init], {}) assert len(grad) == 1 assert len(grad_variance) == 1 @@ -297,13 +295,13 @@ def test_padded_single_array_argument_step(self, mocker, monkeypatch): # Step twice to ensure that `opt.s` does not get reshaped. # If it was reshaped, its shape would not match `new_x` # and an error would get raised. - _ = opt.step(self.cost_fn2, new_x) + _ = opt.step(qnode, new_x) - def test_multiple_argument_step(self, mocker, monkeypatch): + def test_multiple_argument_step(self, mocker, monkeypatch, seed): """Test that a simple QNode with multiple scalar arguments correctly performs an optimization step, and that the single-shot gradients generated have the correct shape""" # pylint: disable=protected-access - dev = qml.device("default.qubit", wires=1, shots=100) + dev = qml.device("default.qubit", wires=1, shots=100, seed=seed) @qml.qnode(dev) def circuit(x, y): @@ -356,11 +354,11 @@ def circuit(x, y): assert np.allclose(grad[p], np.mean(single_shot_grads[p][:s])) assert np.allclose(grad_variance[p], np.var(single_shot_grads[p][:s], ddof=1)) - def test_multiple_array_argument_step(self, mocker, monkeypatch): + def test_multiple_array_argument_step(self, mocker, monkeypatch, seed): """Test that a simple QNode with multiple array arguments correctly performs an optimization step, and that the single-shot gradients generated have the correct shape""" # pylint: disable=protected-access - dev = qml.device("default.qubit", wires=1, shots=100) + dev = qml.device("default.qubit", wires=1, shots=100, seed=seed) @qml.qnode(dev) def circuit(x, y): @@ -600,9 +598,9 @@ class TestOptimization: minimizes simple examples""" @pytest.mark.slow - def test_multi_qubit_rotation(self): + def test_multi_qubit_rotation(self, seed): """Test that multiple qubit rotation can be optimized""" - dev = qml.device("default.qubit", wires=2, shots=100) + dev = qml.device("default.qubit", wires=2, shots=100, seed=seed) @qml.qnode(dev) def circuit(x): @@ -630,7 +628,7 @@ def circuit(x): @pytest.mark.slow def test_vqe_optimization(self, seed): """Test that a simple VQE circuit can be optimized""" - dev = qml.device("default.qubit", wires=2, shots=100) + dev = qml.device("default.qubit", wires=2, shots=100, seed=seed) coeffs = [0.1, 0.2] obs = [qml.PauliZ(0), qml.PauliX(0)] H = qml.Hamiltonian(coeffs, obs) @@ -670,12 +668,11 @@ class TestStepAndCost: """Tests for the step_and_cost method""" @pytest.mark.slow - @flaky(max_runs=3) - def test_qnode_cost(self, tol): + def test_qnode_cost(self, tol, seed): """Test that the cost is correctly returned when using a QNode as the cost function""" - dev = qml.device("default.qubit", wires=1, shots=10) + dev = qml.device("default.qubit", wires=1, shots=10, seed=seed) @qml.qnode(dev, cache=False) def circuit(x): diff --git a/tests/transforms/test_qcut.py b/tests/transforms/test_qcut.py index 2ed9eeeffb3..82d2fefb4cd 100644 --- a/tests/transforms/test_qcut.py +++ b/tests/transforms/test_qcut.py @@ -5580,7 +5580,7 @@ def f(): assert np.isclose(res, res_expected, atol=1e-8) assert cut_circuit.tape.measurements[0].obs.grouping_indices == hamiltonian.grouping_indices - def test_template_with_hamiltonian(self): + def test_template_with_hamiltonian(self, seed): """Test cut with MPS Template""" pytest.importorskip("kahypar") @@ -5620,9 +5620,7 @@ def block(weights, wires): for idx, tape in enumerate(tapes): graph = qcut.tape_to_graph(tape) cut_graph = qcut.find_and_place_cuts( - graph=graph, - cut_strategy=cut_strategy, - replace_wire_cuts=True, + graph=graph, cut_strategy=cut_strategy, replace_wire_cuts=True, seed=seed ) frags, _ = qcut.fragment_graph(cut_graph) diff --git a/tests/transforms/test_split_non_commuting.py b/tests/transforms/test_split_non_commuting.py index b3374fc52ac..37931f70ea3 100644 --- a/tests/transforms/test_split_non_commuting.py +++ b/tests/transforms/test_split_non_commuting.py @@ -812,10 +812,12 @@ def circuit(angles): ), ], ) - def test_mixed_measurement_types(self, grouping_strategy, shots, params, expected_results): + def test_mixed_measurement_types( + self, grouping_strategy, shots, params, expected_results, seed + ): """Tests that a QNode with mixed measurement types is executed correctly""" - dev = qml.device("default.qubit", wires=2, shots=shots) + dev = qml.device("default.qubit", wires=2, shots=shots, seed=seed) obs_list = complex_obs_list if not qml.operation.active_new_opmath():