diff --git a/jaxadi/_ops.py b/jaxadi/_ops.py index fcdc8d5..86cc970 100644 --- a/jaxadi/_ops.py +++ b/jaxadi/_ops.py @@ -94,6 +94,6 @@ OP_ATANH: "jnp.arctanh(work[{0}])", OP_ATAN2: "jnp.arctan2(work[{0}], work[{1}])", OP_CONST: "{0:.16f}", - OP_INPUT: "inputs[{0}, {1}, {2}]", + OP_INPUT: "inputs[{0}][{1}, {2}]", OP_OUTPUT: "work[{0}][0]", } diff --git a/jaxadi/_translate.py b/jaxadi/_translate.py index b617cf9..676b260 100644 --- a/jaxadi/_translate.py +++ b/jaxadi/_translate.py @@ -18,7 +18,7 @@ def translate(func: Function, add_jit=False, add_import=False) -> str: codegen += "@jax.jit\n" if add_jit else "" codegen += f"def evaluate_{func.name()}(*args):\n" # combine all inputs into a single list - codegen += " inputs = jnp.expand_dims(jnp.array(args), axis=-1)\n" + codegen += " inputs = [jnp.expand_dims(jnp.array(arg), axis=-1) for arg in args]\n" # output variables codegen += f" outputs = [jnp.zeros(out) for out in {out_shapes}]\n" diff --git a/tests/test_input.py b/tests/test_input.py new file mode 100644 index 0000000..372244a --- /dev/null +++ b/tests/test_input.py @@ -0,0 +1,18 @@ +import casadi as cs +import jax.numpy as jnp +import numpy as np + +from jaxadi import convert + + +def test_different_shapes(): + x = cs.SX.sym("x", 2, 3) + y = cs.SX.sym("y", 3, 2) + casadi_fn = cs.Function("myfunc", [x, y], [x @ y]) + + jax_fn = convert(casadi_fn, compile=True) + + in1 = jnp.array(np.random.randn(2, 3)) + in2 = jnp.array(np.random.randn(3, 2)) + + jax_fn(in1, in2)