Skip to content

Commit

Permalink
feat: running tests and everything
Browse files Browse the repository at this point in the history
  • Loading branch information
lvjonok committed Sep 3, 2024
1 parent 5596628 commit a2170c5
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 261 deletions.
16 changes: 15 additions & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,21 @@ jobs:
- name: Run pre-commit
run: pre-commit run --all-files --color always --verbose

# TODO: Add a step to run tests
run-tests:
runs-on: ubuntu-latest
needs: ruff-check
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[dev]
- name: Run tests
run: python -m pytest

deploy:
runs-on: ubuntu-latest
Expand Down
10 changes: 10 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: jaxadi
channels:
- conda-forge
dependencies:
- python=3.10
- pip
- pinocchio
- pip:
- robot_descriptions
- jax
21 changes: 21 additions & 0 deletions examples/02_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
This example shows how to convert a matrix multiplication function defined
in CasADi to a JAX-compatible function using the jaxadi library.
It demonstrates defining a function in CasADi, converting it to a JAX function,
and running the compiled function with random input matrices.
"""

import casadi as cs

from jaxadi import convert

# define input variables for the function
x = cs.SX.sym("x", 10, 10)
y = cs.SX.sym("y", 10, 10)
casadi_fn = cs.Function("myfunc", [x, y], [x @ y])

# define jax function from casadi one
jax_fn = convert(casadi_fn, compile=True)

# Run compiled function
jax_fn(cs.np.random.rand(10, 10), cs.np.random.rand(10, 10))
21 changes: 11 additions & 10 deletions examples/03_pinocchio.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import timeit

import casadi as ca
import jax.numpy as jnp
import pinocchio as pin
import pinocchio.casadi as cpin
from robot_descriptions.panda_description import URDF_PATH
import jax.numpy as jnp

from jaxadi import translate, convert
from jaxadi import convert, translate

# Load the Panda robot model
model = pin.buildModelFromUrdf(URDF_PATH)
Expand All @@ -30,19 +32,18 @@
jax_fn = convert(fk, compile=True)

# Evaluate the function performance
import timeit

q_val = ca.np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0, 0])
jax_q_val = jnp.array(q_val)
jax_q_val = jnp.array([[0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0], [0]])

print("Casadi evaluation:")
print(fk(q_val))
print("JAX evaluation:")
print(jax_fn(jax_q_val))

print("Performance comparison:")
print("Casadi evaluation:")
print(timeit.timeit(lambda: fk(q_val), number=100))
# pwease do not run, it will take a lot of time
# print("Performance comparison:")
# print("Casadi evaluation:")
# print(timeit.timeit(lambda: fk(q_val), number=100))

print("JAX evaluation:")
print(timeit.timeit(lambda: jax_fn(jax_q_val), number=100))
# print("JAX evaluation:")
# print(timeit.timeit(lambda: jax_fn(jax_q_val), number=100))
31 changes: 0 additions & 31 deletions examples/codegen.py

This file was deleted.

16 changes: 0 additions & 16 deletions examples/gen.py

This file was deleted.

124 changes: 0 additions & 124 deletions examples/test.py

This file was deleted.

20 changes: 0 additions & 20 deletions examples/test_eval.py

This file was deleted.

2 changes: 1 addition & 1 deletion jaxadi/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ def convert(casadi_fn: Function, compile=False) -> Callable[..., Any]:
jax_fn = declare(jax_str)

if compile:
compile_fn(jax_fn)
compile_fn(jax_fn, casadi_fn)

return jax_fn
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies = ["casadi", "jax"]
readme = "README.md"
requires-python = ">=3.10"

optional-dependencies = { "dev" = ["pre-commit"] }
optional-dependencies = { "dev" = ["pre-commit", "pytest"] }

[project.urls]
homepage = "https://github.com/based-robotics/jaxadi"
Expand Down
Loading

0 comments on commit a2170c5

Please sign in to comment.