[BUG] jax==0.6.0 causes type mismatches between jax.lax.scan outputs and initial carry values #1184

@RuanJohn

Describe the bug

A recent version of JAX (jax==0.6.0 in this case) breaks the jax.lax.scan calls in the GAE and system loss calculations. The initial carry values are not explicitly placed on the learner devices, so their types no longer match the types of the carry values the scan body returns.

To Reproduce

Steps to reproduce the behavior:

  1. uv sync --extra cuda12
  2. python mava/systems/ppo/sebulba/ff_ippo.py env=lbf_gym

Expected behavior

The system fails with an error message along these lines:

Traceback (most recent call last):
  File "Mava/mava/systems/ppo/sebulba/ff_ippo.py", line 400, in learner_thread
    learner_state, train_metrics = learn_fn(learner_state, traj_batch)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "Mava/mava/systems/ppo/sebulba/ff_ippo.py", line 366, in learner_fn
    learner_state, loss_info = _update_step(learner_state, traj_batch)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "Mava/mava/systems/ppo/sebulba/ff_ippo.py", line 196, in _update_step
    advantages, targets = calculate_gae(
                          ^^^^^^^^^^^^^^
  File "Mava/mava/utils/multistep.py", line 61, in calculate_gae
    _, advantages = jax.lax.scan(
                    ^^^^^^^^^^^^^
TypeError: scan body function carry input and carry output must have equal types, but they differ:

The input carry component carry[0] has type float32[32,2] but the corresponding output carry component has type float32[32,2]{learner_devices}, so the varying manual axes do not match.

This might be fixed by applying `jax.lax.pvary(..., ('learner_devices',))` to the initial carry value corresponding to the input carry component carry[0].

Revise the function so that all output types match the corresponding input types.

Additional context

A potential solution is discussed here; we just need to implement it.

Possible Solution

Implement the solution discussed above. In the meantime, we can pin to jax==0.5.2.
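
As a rough illustration of the hint in the error message, here is a minimal sketch of a GAE scan whose initial carry can be marked as varying over the learner axis. It is not Mava's `calculate_gae`: the function name `gae_sketch`, its arguments, and the optional `axis_name` parameter are all assumptions made for this example. The `jax.lax.pvary` call is the one suggested by the error message and only makes sense inside a `shard_map` over that axis, so with `axis_name=None` the function runs as ordinary single-device JAX.

```python
import jax
import jax.numpy as jnp


def gae_sketch(rewards, values, dones, last_value,
               gamma=0.99, gae_lambda=0.95, axis_name=None):
    """Hypothetical GAE helper (not Mava's implementation).

    rewards, values, dones: arrays of shape [T, ...]; last_value: shape [...].
    axis_name: name of the manual mesh axis (e.g. "learner_devices") when the
    function is called inside shard_map, otherwise None.
    """

    def step(carry, transition):
        gae, next_value = carry
        reward, value, done = transition
        delta = reward + gamma * next_value * (1.0 - done) - value
        gae = delta + gamma * gae_lambda * (1.0 - done) * gae
        # The returned carry must have the same type as the initial carry.
        return (gae, value), gae

    # A freshly created constant is not varying over any manual axis.
    init_gae = jnp.zeros(last_value.shape, last_value.dtype)
    if axis_name is not None:
        # Assumption: we are inside shard_map over `axis_name`. This marks the
        # constant as varying so it matches the carry the scan body returns,
        # as the error message suggests.
        init_gae = jax.lax.pvary(init_gae, (axis_name,))

    _, advantages = jax.lax.scan(
        step, (init_gae, last_value), (rewards, values, dones), reverse=True
    )
    return advantages, advantages + values


# Usage outside shard_map (axis_name=None), with 3 timesteps and 2 agents.
rewards = jnp.ones((3, 2))
values = jnp.zeros((3, 2))
dones = jnp.zeros((3, 2))
last_value = jnp.zeros((2,))
advantages, targets = gae_sketch(
    rewards, values, dones, last_value, gamma=1.0, gae_lambda=1.0
)
```

Only the zero-initialised accumulator is passed through `jax.lax.pvary` here, on the assumption that `last_value` is computed from sharded inputs and is therefore already varying over the learner axis.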
