Skip to content

Commit

Permalink
fix: remove legacy translator
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Sep 9, 2024
1 parent 9ec5d1a commit bdeed91
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 54 deletions.
4 changes: 2 additions & 2 deletions examples/00_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import casadi as cs

from jaxadi import translate, legacy_translate, translate_one_line
from jaxadi import translate

# define input variables for the function
x = cs.SX.sym("x", 3)
Expand All @@ -20,4 +20,4 @@

print("Translated JAX function:")
# secure add_import and add_jit to True to get the complete code
print(translate_one_line(casadi_function, add_import=True, add_jit=True))
print(translate(casadi_function, add_import=True, add_jit=True))
3 changes: 1 addition & 2 deletions examples/02_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
import casadi as cs

from jaxadi import convert
from jaxadi import translate_one_line

# define input variables for the function
x = cs.SX.sym("x", 10, 10)
y = cs.SX.sym("y", 10, 10)
casadi_fn = cs.Function("myfunc", [x, y], [x @ y])

# define jax function from casadi one
jax_fn = convert(casadi_fn, compile=True, translator=translate_one_line)
jax_fn = convert(casadi_fn, compile=True)

# Run compiled function
jax_fn(cs.np.random.rand(10, 10), cs.np.random.rand(10, 10))
50 changes: 0 additions & 50 deletions jaxadi/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,56 +48,6 @@
OP_TWICE,
)

OP_JAX_DICT = {
OP_ASSIGN: "\n work = work.at[{0}].set(work[{1}])",
OP_ADD: "\n work = work.at[{0}].set(work[{1}] + work[{2}])",
OP_SUB: "\n work = work.at[{0}].set(work[{1}] - work[{2}])",
OP_MUL: "\n work = work.at[{0}].set(work[{1}] * work[{2}])",
OP_DIV: "\n work = work.at[{0}].set(work[{1}] / work[{2}])",
OP_NEG: "\n work = work.at[{0}].set(-work[{1}])",
OP_EXP: "\n work = work.at[{0}].set(jnp.exp(work[{1}]))",
OP_LOG: "\n work = work.at[{0}].set(jnp.log(work[{1}]))",
OP_POW: "\n work = work.at[{0}].set(jnp.power(work[{1}], work[{2}]))",
OP_CONSTPOW: "\n work = work.at[{0}].set(jnp.power(work[{1}], work[{2}]))",
OP_SQRT: "\n work = work.at[{0}].set(jnp.sqrt(work[{1}]))",
OP_SQ: "\n work = work.at[{0}].set(work[{1}] * work[{2}])",
OP_TWICE: "\n work = work.at[{0}].set(2 * work[{1}])",
OP_SIN: "\n work = work.at[{0}].set(jnp.sin(work[{1}]))",
OP_COS: "\n work = work.at[{0}].set(jnp.cos(work[{1}]))",
OP_TAN: "\n work = work.at[{0}].set(jnp.tan(work[{1}]))",
OP_ASIN: "\n work = work.at[{0}].set(jnp.arcsin(work[{1}]))",
OP_ACOS: "\n work = work.at[{0}].set(jnp.arccos(work[{1}]))",
OP_ATAN: "\n work = work.at[{0}].set(jnp.arctan(work[{1}]))",
OP_LT: "\n work = work.at[{0}].set(work[{1}] < work[{2}])",
OP_LE: "\n work = work.at[{0}].set(work[{1}] <= work[{2}])",
OP_EQ: "\n work = work.at[{0}].set(work[{1}] == work[{2}])",
OP_NE: "\n work = work.at[{0}].set(work[{1}] != work[{2}])",
OP_NOT: "\n work = work.at[{0}].set(jnp.logical_not(work[{1}]))",
OP_AND: "\n work = work.at[{0}].set(jnp.logical_and(work[{1}], work[{2}]))",
OP_OR: "\n work = work.at[{0}].set(jnp.logical_or(work[{1}], work[{2}]))",
OP_FLOOR: "\n work = work.at[{0}].set(jnp.floor(work[{1}]))",
OP_CEIL: "\n work = work.at[{0}].set(jnp.ceil(work[{1}]))",
OP_FMOD: "\n work = work.at[{0}].set(jnp.fmod(work[{1}], work[{2}]))",
OP_FABS: "\n work = work.at[{0}].set(jnp.abs(work[{1}]))",
OP_SIGN: "\n work = work.at[{0}].set(jnp.sign(work[{1}]))",
OP_COPYSIGN: "\n work = work.at[{0}].set(jnp.copysign(work[{1}], work[{2}]))",
OP_IF_ELSE_ZERO: "\n work = work.at[{0}].set(jnp.where(work[{1}] == 0, 0, work[{2}]))",
OP_ERF: "\n work = work.at[{0}].set(jax.scipy.special.erf(work[{1}]))",
OP_FMIN: "\n work = work.at[{0}].set(jnp.minimum(work[{1}], work[{2}]))",
OP_FMAX: "\n work = work.at[{0}].set(jnp.maximum(work[{1}], work[{2}]))",
OP_INV: "\n work = work.at[{0}].set(1.0 / work[{1}])",
OP_SINH: "\n work = work.at[{0}].set(jnp.sinh(work[{1}]))",
OP_COSH: "\n work = work.at[{0}].set(jnp.cosh(work[{1}]))",
OP_TANH: "\n work = work.at[{0}].set(jnp.tanh(work[{1}]))",
OP_ASINH: "\n work = work.at[{0}].set(jnp.arcsinh(work[{1}]))",
OP_ACOSH: "\n work = work.at[{0}].set(jnp.arccosh(work[{1}]))",
OP_ATANH: "\n work = work.at[{0}].set(jnp.arctanh(work[{1}]))",
OP_ATAN2: "\n work = work.at[{0}].set(jnp.arctan2(work[{1}], work[{2}]))",
OP_CONST: "\n work = work.at[{0}].set({1:.16f})",
OP_INPUT: "\n work = work.at[{0}].set(inputs[{1}][{2}, {3}])",
OP_OUTPUT: "\n outputs[{0}] = outputs[{0}].at[{1}, {2}].set(work[{3}][0])",
}

OP_JAX_VALUE_DICT = {
OP_ASSIGN: "work[{0}]",
OP_ADD: "work[{0}] + work[{1}]",
Expand Down

0 comments on commit bdeed91

Please sign in to comment.