Skip to content

Commit

Permalink
jax.vmap: convert mapped input arguments to array
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 10, 2025
1 parent c1de7c7 commit de0972c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
4 changes: 4 additions & 0 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,10 @@ def vmap_f(*args, **kwargs):
f = lu.wrap_init(fun)
flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
args_flat = [lax_internal.asarray(arg)
if isinstance(arg, np.ndarray) and ax is not None
else arg
for arg, ax in zip(args_flat, in_axes_flat)]
axis_size_ = (axis_size if axis_size is not None else
_mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
try:
Expand Down
17 changes: 15 additions & 2 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1812,8 +1812,21 @@ def f(x):
self.assertIsInstance(x, jax.Array)
self.assertNotIsInstance(x, np.ndarray)
return x + 2
jit(f)(3)
jax.vmap(f)(np.arange(3))
result = jit(f)(3)
self.assertIsInstance(result, jax.Array)

result = jax.vmap(f)(np.arange(3))
self.assertIsInstance(result, jax.Array)

def test_numpy_input_with_trivial_function(self):
# Regression test for https://github.com/jax-ml/jax/issues/25745
def f(x):
return x
jit_result = jit(f)(3)
self.assertIsInstance(jit_result, jax.Array)

vmap_result = jax.vmap(f)(np.arange(3))
self.assertIsInstance(vmap_result, jax.Array)

def test_device_put_and_get(self):
x = np.arange(12.).reshape((3, 4)).astype("float32")
Expand Down

0 comments on commit de0972c

Please sign in to comment.