Skip to content

Commit

Permalink
fix: densify structural zeros
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Dec 7, 2024
1 parent 6e6c97e commit 0e5b0da
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 0 deletions.
3 changes: 3 additions & 0 deletions jaxadi/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ._graph import translate as graph_translate
from ._expand import translate as expand_translate
from ._compile import compile as compile_fn
from ._preprocess import densify


def convert(casadi_fn: Function, translate=None, compile=False) -> Callable[..., Any]:
Expand All @@ -21,6 +22,8 @@ def convert(casadi_fn: Function, translate=None, compile=False) -> Callable[...,
if translate is None:
translate = graph_translate

casadi_fn = densify(casadi_fn)

jax_str = translate(casadi_fn)
jax_fn = declare(jax_str)

Expand Down
14 changes: 14 additions & 0 deletions jaxadi/_preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from casadi import densify as cs_densify
from casadi import Function


def densify(func: Function):
_i = func.sx_in()
_o = func(*_i)
if not isinstance(_o, tuple):
_o = [_o]
_dense_o = []
for i, o in enumerate(_o):
_dense_o.append(cs_densify(o))
_func = Function(func.name(), _i, _dense_o)
return _func
26 changes: 26 additions & 0 deletions tests/test_casadi_equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

from jaxadi import convert

from jaxadi import graph_translate
from jaxadi import expand_translate

# Set a fixed seed for reproducibility
np.random.seed(42)

Expand Down Expand Up @@ -35,6 +38,29 @@ def test_simo_trig():
compare_results(casadi_f, jax_f, x_val)


def test_all_zeros():
X = ca.SX.sym("x", 2)
A = np.zeros((2, 2))
Y = ca.jacobian(A @ X, X)

casadi_f = ca.Function("foo", [X], [Y])
jax_f = convert(casadi_f)
x_val = np.random.randn(2, 1)
compare_results(casadi_f, jax_f, x_val)


def test_structural_zeros():
X = ca.SX.sym("x", 2)
A = np.ones((2, 2))
A[1, :] = 0.0
Y = ca.jacobian(A @ X, X)

casadi_f = ca.Function("foo", [X], [ca.densify(Y)])
jax_f = convert(casadi_f, translate=expand_translate)
x_val = np.random.randn(2, 1)
compare_results(casadi_f, jax_f, x_val)


def test_simo_poly():
x = ca.SX.sym("x", 1, 1)
casadi_f = ca.Function("simo_poly", [x], [x**2, x**3, ca.sqrt(x)])
Expand Down

0 comments on commit 0e5b0da

Please sign in to comment.