diff --git a/jax/_src/api.py b/jax/_src/api.py index 4bf964a72239..409466a978af 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -987,6 +987,10 @@ def vmap_f(*args, **kwargs): f = lu.wrap_init(fun) flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree) in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True) + args_flat = [lax_internal.asarray(arg) + if isinstance(arg, np.ndarray) and ax is not None + else arg + for arg, ax in zip(args_flat, in_axes_flat)] axis_size_ = (axis_size if axis_size is not None else _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap")) try: diff --git a/tests/api_test.py b/tests/api_test.py index 379c6390061a..12bc4f32744e 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1812,8 +1812,21 @@ def f(x): self.assertIsInstance(x, jax.Array) self.assertNotIsInstance(x, np.ndarray) return x + 2 - jit(f)(3) - jax.vmap(f)(np.arange(3)) + result = jit(f)(3) + self.assertIsInstance(result, jax.Array) + + result = jax.vmap(f)(np.arange(3)) + self.assertIsInstance(result, jax.Array) + + def test_numpy_input_with_trivial_function(self): + # Regression test for https://github.com/jax-ml/jax/issues/25745 + def f(x): + return x + jit_result = jit(f)(3) + self.assertIsInstance(jit_result, jax.Array) + + vmap_result = jax.vmap(f)(np.arange(3)) + self.assertIsInstance(vmap_result, jax.Array) def test_device_put_and_get(self): x = np.arange(12.).reshape((3, 4)).astype("float32")