Skip to content

Commit

Permalink
Remove validation methods from primitive base classes (backport #11052)…
Browse files Browse the repository at this point in the history
… (#11532)

* Remove validation methods from primitive base classes (#11052)

* Remove validation methods from primitive base classes

This deprecates the argument validation methods from primitive base classes and moves them to separate helper functions. These methods unnecessarily bloat the base classes, and are odd to have when the BasePrimitive doesn't even define a run method to validate. There is no reason primitive implementations need to use the same validation as these base classes either. A follow up will be to remove the validation from the base `run` methods and have subclasses implement their own validation.

* Apply suggestions from code review

* Update qiskit/primitives/base/base_estimator.py

---------

Co-authored-by: Ikko Hamamura <[email protected]>
(cherry picked from commit 05d958b)

* Update qiskit/primitives/base/base_estimator.py

* Add missing import

---------

Co-authored-by: Christopher J. Wood <[email protected]>
Co-authored-by: Matthew Treinish <[email protected]>
Co-authored-by: Jake Lishman <[email protected]>
  • Loading branch information
4 people authored Feb 1, 2024
1 parent 3d3edd0 commit 6177feb
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 145 deletions.
39 changes: 10 additions & 29 deletions qiskit/primitives/base/base_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@
from qiskit.providers import JobV1 as Job
from qiskit.quantum_info.operators import SparsePauliOp
from qiskit.quantum_info.operators.base_operator import BaseOperator
from qiskit.utils.deprecation import deprecate_func

from ..utils import init_observable
from .base_primitive import BasePrimitive
from . import validation

if typing.TYPE_CHECKING:
from qiskit.opflow import PauliSumOp
Expand Down Expand Up @@ -175,18 +176,11 @@ def run(
TypeError: Invalid argument type given.
ValueError: Invalid argument values given.
"""
# Singular validation
circuits = self._validate_circuits(circuits)
observables = self._validate_observables(observables)
parameter_values = self._validate_parameter_values(
parameter_values,
default=[()] * len(circuits),
# Validation
circuits, observables, parameter_values = validation._validate_estimator_args(
circuits, observables, parameter_values
)

# Cross-validation
self._cross_validate_circuits_parameter_values(circuits, parameter_values)
self._cross_validate_circuits_observables(circuits, observables)

# Options
run_opts = copy(self.options)
run_opts.update_options(**run_options)
Expand All @@ -206,34 +200,21 @@ def _run(
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> T:
raise NotImplementedError("The subclass of BaseEstimator must implment `_run` method.")
raise NotImplementedError("The subclass of BaseEstimator must implement `_run` method.")

@staticmethod
@deprecate_func(since="0.46.0")
def _validate_observables(
observables: Sequence[BaseOperator | PauliSumOp | str] | BaseOperator | PauliSumOp | str,
) -> tuple[SparsePauliOp, ...]:
if isinstance(observables, str) or not isinstance(observables, Sequence):
observables = (observables,)
if len(observables) == 0:
raise ValueError("No observables were provided.")
return tuple(init_observable(obs) for obs in observables)
return validation._validate_observables(observables)

@staticmethod
@deprecate_func(since="0.46.0")
def _cross_validate_circuits_observables(
circuits: tuple[QuantumCircuit, ...], observables: tuple[BaseOperator | PauliSumOp, ...]
) -> None:
if len(circuits) != len(observables):
raise ValueError(
f"The number of circuits ({len(circuits)}) does not match "
f"the number of observables ({len(observables)})."
)
for i, (circuit, observable) in enumerate(zip(circuits, observables)):
if circuit.num_qubits != observable.num_qubits:
raise ValueError(
f"The number of qubits of the {i}-th circuit ({circuit.num_qubits}) does "
f"not match the number of qubits of the {i}-th observable "
f"({observable.num_qubits})."
)
return validation._cross_validate_circuits_observables(circuits, observables)

@property
def circuits(self) -> tuple[QuantumCircuit, ...]:
Expand Down
79 changes: 11 additions & 68 deletions qiskit/primitives/base/base_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from abc import ABC
from collections.abc import Sequence

import numpy as np

from qiskit.circuit import QuantumCircuit
from qiskit.providers import Options
from qiskit.utils.deprecation import deprecate_func

from . import validation


class BasePrimitive(ABC):
Expand Down Expand Up @@ -49,83 +50,25 @@ def set_options(self, **fields):
self._run_options.update_options(**fields)

@staticmethod
@deprecate_func(since="0.46.0")
def _validate_circuits(
circuits: Sequence[QuantumCircuit] | QuantumCircuit,
) -> tuple[QuantumCircuit, ...]:
if isinstance(circuits, QuantumCircuit):
circuits = (circuits,)
elif not isinstance(circuits, Sequence) or not all(
isinstance(cir, QuantumCircuit) for cir in circuits
):
raise TypeError("Invalid circuits, expected Sequence[QuantumCircuit].")
elif not isinstance(circuits, tuple):
circuits = tuple(circuits)
if len(circuits) == 0:
raise ValueError("No circuits were provided.")
return circuits
return validation._validate_circuits(circuits)

@staticmethod
@deprecate_func(since="0.46.0")
def _validate_parameter_values(
parameter_values: Sequence[Sequence[float]] | Sequence[float] | float | None,
default: Sequence[Sequence[float]] | Sequence[float] | None = None,
) -> tuple[tuple[float, ...], ...]:
# Allow optional (if default)
if parameter_values is None:
if default is None:
raise ValueError("No default `parameter_values`, optional input disallowed.")
parameter_values = default

# Support numpy ndarray
if isinstance(parameter_values, np.ndarray):
parameter_values = parameter_values.tolist()
elif isinstance(parameter_values, Sequence):
parameter_values = tuple(
vector.tolist() if isinstance(vector, np.ndarray) else vector
for vector in parameter_values
)

# Allow single value
if _isreal(parameter_values):
parameter_values = ((parameter_values,),)
elif isinstance(parameter_values, Sequence) and not any(
isinstance(vector, Sequence) for vector in parameter_values
):
parameter_values = (parameter_values,)

# Validation
if (
not isinstance(parameter_values, Sequence)
or not all(isinstance(vector, Sequence) for vector in parameter_values)
or not all(all(_isreal(value) for value in vector) for vector in parameter_values)
):
raise TypeError("Invalid parameter values, expected Sequence[Sequence[float]].")

return tuple(tuple(float(value) for value in vector) for vector in parameter_values)
return validation._validate_parameter_values(parameter_values, default=default)

@staticmethod
@deprecate_func(since="0.46.0")
def _cross_validate_circuits_parameter_values(
circuits: tuple[QuantumCircuit, ...], parameter_values: tuple[tuple[float, ...], ...]
) -> None:
if len(circuits) != len(parameter_values):
raise ValueError(
f"The number of circuits ({len(circuits)}) does not match "
f"the number of parameter value sets ({len(parameter_values)})."
)
for i, (circuit, vector) in enumerate(zip(circuits, parameter_values)):
if len(vector) != circuit.num_parameters:
raise ValueError(
f"The number of values ({len(vector)}) does not match "
f"the number of parameters ({circuit.num_parameters}) for the {i}-th circuit."
)


def _isint(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool:
"""Check if object is int."""
int_types = (int, np.integer)
return isinstance(obj, int_types) and not isinstance(obj, bool)


def _isreal(obj: Sequence[Sequence[float]] | Sequence[float] | float) -> bool:
"""Check if object is a real number: int or float except ``±Inf`` and ``NaN``."""
float_types = (float, np.floating)
return _isint(obj) or isinstance(obj, float_types) and float("-Inf") < obj < float("Inf")
return validation._cross_validate_circuits_parameter_values(
circuits, parameter_values=parameter_values
)
44 changes: 8 additions & 36 deletions qiskit/primitives/base/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,13 @@
from copy import copy
from typing import Generic, TypeVar

from qiskit.circuit import ControlFlowOp, Measure, QuantumCircuit
from qiskit.utils.deprecation import deprecate_func
from qiskit.circuit import QuantumCircuit
from qiskit.circuit.parametertable import ParameterView
from qiskit.providers import JobV1 as Job

from .base_primitive import BasePrimitive
from . import validation

T = TypeVar("T", bound=Job)

Expand Down Expand Up @@ -130,15 +132,8 @@ def run(
Raises:
ValueError: Invalid arguments are given.
"""
# Singular validation
circuits = self._validate_circuits(circuits)
parameter_values = self._validate_parameter_values(
parameter_values,
default=[()] * len(circuits),
)

# Cross-validation
self._cross_validate_circuits_parameter_values(circuits, parameter_values)
# Validation
circuits, parameter_values = validation._validate_sampler_args(circuits, parameter_values)

# Options
run_opts = copy(self.options)
Expand All @@ -157,27 +152,15 @@ def _run(
parameter_values: tuple[tuple[float, ...], ...],
**run_options,
) -> T:
raise NotImplementedError("The subclass of BaseSampler must implment `_run` method.")
raise NotImplementedError("The subclass of BaseSampler must implement `_run` method.")

@classmethod
@deprecate_func(since="0.46.0")
def _validate_circuits(
cls,
circuits: Sequence[QuantumCircuit] | QuantumCircuit,
) -> tuple[QuantumCircuit, ...]:
circuits = super()._validate_circuits(circuits)
for i, circuit in enumerate(circuits):
if circuit.num_clbits == 0:
raise ValueError(
f"The {i}-th circuit does not have any classical bit. "
"Sampler requires classical bits, plus measurements "
"on the desired qubits."
)
if not _has_measure(circuit):
raise ValueError(
f"The {i}-th circuit does not have Measure instruction. "
"Without measurements, the circuit cannot be sampled from."
)
return circuits
return validation._validate_circuits(circuits, requires_measure=True)

@property
def circuits(self) -> tuple[QuantumCircuit, ...]:
Expand All @@ -196,14 +179,3 @@ def parameters(self) -> tuple[ParameterView, ...]:
List of the parameters in each quantum circuit.
"""
return tuple(self._parameters)


def _has_measure(circuit: QuantumCircuit) -> bool:
for instruction in reversed(circuit):
if isinstance(instruction.operation, Measure):
return True
elif isinstance(instruction.operation, ControlFlowOp):
for block in instruction.operation.blocks:
if _has_measure(block):
return True
return False
Loading

0 comments on commit 6177feb

Please sign in to comment.