diff --git a/test/model_test.py b/test/model_test.py index 6f74dbc4..2e1ece72 100644 --- a/test/model_test.py +++ b/test/model_test.py @@ -89,7 +89,7 @@ class MjxModelsTest(parameterized.TestCase): @parameterized.parameters(_MJX_MODEL_XMLS) def test_compiles_and_steps(self, xml_path: pathlib.Path) -> None: model = mujoco.MjModel.from_xml_path(str(xml_path)) - model = mjx.device_put(model) + model = mjx.put_model(model) data = mjx.make_data(model) ctrlrange = jp.where( model.actuator_ctrllimited[:, None],