Hi,
Thanks for the cool work! I'm running into a blocking issue, however: `mjx.get_data(mj_model, batch)` is really slow. `jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))` runs in parallel as expected, but I can't find an equivalent jitted version of `get_data`. Is this function not called in your training code, or is the performance hit simply ignored?
Timing data (n=100 envs, 100 sim steps):

| Setup | Time | FPS | ms per step |
|---|---|---|---|
| without `get_data` | 0.12 s | 849.45 | 1.18 ms |
| with `get_data` | 9.23 s | 10.84 | 92.26 ms |
```python
import jax
import mujoco
from mujoco import mjx

xml = """
<mujoco>
  <worldbody>
    <light name="top" pos="0 0 1"/>
    <body name="box_and_sphere" euler="0 0 -30">
      <joint name="swing" type="hinge" axis="1 -1 0" pos="-.2 -.2 -.2"/>
      <geom name="red_box" type="box" size=".2 .2 .2" rgba="1 0 0 1"/>
      <geom name="green_sphere" pos=".2 .2 .2" size=".1" rgba="0 1 0 1"/>
    </body>
  </worldbody>
</mujoco>
"""

# Make model, data, and renderer
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
mujoco.mj_resetData(mj_model, mj_data)

# Batch of 100 envs, each with a random initial hinge angle
batch_size = 100
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, batch_size)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (1,))))(rng)

jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
jit_get_data = jax.jit(jax.vmap(mjx.get_data, in_axes=(None, 0)))

print("Taking jit warm up step")
jit_step(mjx_model, batch)
# mjx.get_data(mj_model, batch)  # works
jit_get_data(mj_model, batch)  # errors
```

Error:
```
jstm ~/Projects/mpc-sfd-conda pixi run python scripts/fps_tests/mujoco_mjx_fps_test.py
Taking jit warm up step
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jstm/Projects/mpc-sfd-conda/scripts/fps_tests/mujoco_mjx_fps_test.py", line 77, in <module>
    jit_get_data(mj_model, batch)
TypeError: Error interpreting argument to <function get_data at 0x7f71fafa29e0> as an abstract array. The problematic value is of type <class 'mujoco._structs.MjModel'> and was passed to the function at path m.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
```
related: google-deepmind/mujoco#2718
Thanks