diff --git a/brax/contact.py b/brax/contact.py index 6afe2362..757824b1 100644 --- a/brax/contact.py +++ b/brax/contact.py @@ -35,7 +35,8 @@ def get(sys: System, x: Transform) -> Optional[Contact]: Returns: Contact pytree """ - ncon = mjx.ncon(sys) + # TODO: use mjx.ncon. + ncon = mjx._src.collision_driver.ncon(sys) if not ncon: return None diff --git a/brax/envs/env_test.py b/brax/envs/env_test.py index 87ddabd4..ee15d7ce 100644 --- a/brax/envs/env_test.py +++ b/brax/envs/env_test.py @@ -57,16 +57,6 @@ def testSpeed(self, backend, env_name, expected_sps): ) self.assertGreater(mean_sps, expected_sps * 0.99) - @parameterized.parameters(['mjx', 'generalized', 'spring', 'positional']) - def test_render(self, backend): - env = envs.create( - 'ant', - backend=backend, - ) - state = jax.jit(env.reset)(jax.random.PRNGKey(0)) - images = env.render([state.pipeline_state]) - self.assertLen(images, 1) - self.assertEqual(images[0].shape, (240, 320, 3)) if __name__ == '__main__':