diff --git a/aqt/jax/v2/config.py b/aqt/jax/v2/config.py index 1bde8dd8..da6c1832 100644 --- a/aqt/jax/v2/config.py +++ b/aqt/jax/v2/config.py @@ -216,7 +216,7 @@ def _set_noise_fn( def set_constant_calibration( - cfg: DotGeneral, bound: float = 1.0, bias: float | None = None + cfg: DotGeneral, bound: Union[jnp.ndarray, float], bias: float | None = None ): """Sets the static bound for calibration.""" calibration_cls = functools.partial(