
MJX get_data is really slow - is there a JIT'ed version? #161

@jstmn

Hi,

Thanks for the cool work! I'm hitting a blocking issue, however: mjx.get_data(mj_model, batch) is really slow. jax.jit(jax.vmap(mjx.step, in_axes=(None, 0))) runs in parallel as expected, but I'm having trouble finding an equivalent jitted version of get_data. Is this function not called in your training code? Or is the performance hit simply ignored?
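
For context on what a host-side alternative could look like: if only a few fields are needed outside of JAX, one option is to pull just those arrays to the host in a single transfer instead of rebuilding full MjData objects. The sketch below is an illustration under assumptions, not the library's recommended path. The `batch_state` dict is a hypothetical stand-in for a batched MJX state, where the leaves would be attributes such as `batch.qpos` and `batch.qvel`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for a batched state with 100 envs; with MJX the
# leaves would be attributes such as batch.qpos and batch.qvel.
batch_state = {
    "qpos": jnp.zeros((100, 1)),
    "qvel": jnp.ones((100, 1)),
}

# A single call transfers every leaf of the pytree to host memory
# and returns NumPy arrays with the same structure.
host_state = jax.device_get(batch_state)

print(type(host_state["qpos"]), host_state["qpos"].shape)
```

Whether this is enough depends on what the host code needs. It does not populate a mujoco.MjData, so anything that requires one (for example a renderer) would still need get_data.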

Timing data:

With n=100 envs, 100 sim steps, no get_data:
# time:          0.12 s
# fps:           849.45
# ms per step:   1.18 ms

With n=100 envs, 100 sim steps, with get_data:
# time:          9.23 s
# fps:           10.84
# ms per step:   92.26 ms
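
One caveat when reading numbers like these: JAX dispatches work asynchronously, so a loop that never reads a result can stop the clock before the device has finished, while any call that copies data to the host forces a wait. Part of the gap above may therefore come from synchronization rather than from the copy itself; this has not been verified for the script in question. A minimal sketch of timing with an explicit sync follows, where `toy_step` is a hypothetical stand-in for the batched physics step and not an MJX call.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x):
    # Hypothetical stand-in for a batched physics step; not an MJX call.
    return jnp.sin(x) + 0.01 * x


def time_steps(x, n_steps, sync_every_step):
    start = time.perf_counter()
    for _ in range(n_steps):
        x = toy_step(x)
        if sync_every_step:
            # Wait for the device to finish this step before continuing.
            x.block_until_ready()
    # Always wait for outstanding work before stopping the clock.
    x.block_until_ready()
    return x, time.perf_counter() - start


x0 = jnp.ones((100, 8))
toy_step(x0).block_until_ready()  # warm-up call so compilation is excluded

final_async, t_async = time_steps(x0, 100, sync_every_step=False)
final_sync, t_sync = time_steps(x0, 100, sync_every_step=True)
print(f"sync at end only: {t_async:.4f} s, sync every step: {t_sync:.4f} s")
```

Calling `block_until_ready()` on the last result before reading the clock keeps both measurements comparable, since each then includes the full device time.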

Code:

import jax
import mujoco
from mujoco import mjx

xml = """
<mujoco>
  <worldbody>
    <light name="top" pos="0 0 1"/>
    <body name="box_and_sphere" euler="0 0 -30">
      <joint name="swing" type="hinge" axis="1 -1 0" pos="-.2 -.2 -.2"/>
      <geom name="red_box" type="box" size=".2 .2 .2" rgba="1 0 0 1"/>
      <geom name="green_sphere" pos=".2 .2 .2" size=".1" rgba="0 1 0 1"/>
    </body>
  </worldbody>
</mujoco>
"""

# Make model, data, and renderer
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
mujoco.mj_resetData(mj_model, mj_data)

batch_size = 100
rng = jax.random.PRNGKey(0)
rng = jax.random.split(rng, batch_size)
batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (1,))))(rng)
jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
jit_get_data = jax.jit(jax.vmap(mjx.get_data, in_axes=(None, 0)))

print("Taking jit warm up step")
jit_step(mjx_model, batch)
# mjx.get_data(mj_model, batch) # works
jit_get_data(mj_model, batch) # errors

Error:

jstm ~/Projects/mpc-sfd-conda pixi run python scripts/fps_tests/mujoco_mjx_fps_test.py
Taking jit warm up step
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jstm/Projects/mpc-sfd-conda/scripts/fps_tests/mujoco_mjx_fps_test.py", line 77, in <module>
    jit_get_data(mj_model, batch)
TypeError: Error interpreting argument to <function get_data at 0x7f71fafa29e0> as an abstract array. The problematic value is of type <class 'mujoco._structs.MjModel'> and was passed to the function at path m.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
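
To illustrate the class of error in isolation, here is a minimal sketch with plain JAX. `HostModel` and `scaled_sum` are hypothetical names that stand in for a host-side object such as mujoco.MjModel and a function that takes it; they are not part of MuJoCo or MJX. The sketch shows that jax.jit rejects a non-array argument unless it is marked static. It does not show that mjx.get_data can be jitted this way: get_data appears to copy device arrays into host-side MjData objects, which is not something tracing can express, so marking the model static may simply move the failure elsewhere.

```python
import jax
import jax.numpy as jnp


class HostModel:
    """Hypothetical stand-in for a host-side object such as mujoco.MjModel."""

    def __init__(self, scale):
        self.scale = scale


def scaled_sum(model, x):
    # Uses a plain Python attribute of the host object inside the computation.
    return model.scale * jnp.sum(x)


x = jnp.arange(4.0)
model = HostModel(2.0)

# Passing the plain Python object as a traced argument fails with TypeError.
try:
    jax.jit(scaled_sum)(model, x)
    jit_error = None
except TypeError as err:
    jit_error = err

# Marking the argument static works here because the object is hashable
# (default identity hash) and the function body is traceable.
static_result = jax.jit(scaled_sum, static_argnums=0)(model, x)

print(type(jit_error).__name__, float(static_result))
```

Note that a static argument becomes part of the compilation cache key, so passing a different object triggers recompilation.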

related: google-deepmind/mujoco#2718

Thanks
