Description
Describe the bug
Recent versions of JAX (jax==0.6.0 in this case) break the jax.lax.scan calls in the GAE and system loss calculations. The reason is that the initial carry values are not explicitly put on the learner device, which leads to type mismatches between the carry input and the carry output.
To Reproduce
Steps to reproduce the behavior:
uv sync --extra cuda12
python mava/systems/ppo/sebulba/ff_ippo.py env=lbf_gym
Expected behavior
The system currently fails with an error message along these lines:
Traceback (most recent call last):
File "Mava/mava/systems/ppo/sebulba/ff_ippo.py", line 400, in learner_thread
learner_state, train_metrics = learn_fn(learner_state, traj_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "Mava/mava/systems/ppo/sebulba/ff_ippo.py", line 366, in learner_fn
learner_state, loss_info = _update_step(learner_state, traj_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "Mava/mava/systems/ppo/sebulba/ff_ippo.py", line 196, in _update_step
advantages, targets = calculate_gae(
^^^^^^^^^^^^^^
File "Mava/mava/utils/multistep.py", line 61, in calculate_gae
_, advantages = jax.lax.scan(
^^^^^^^^^^^^^
TypeError: scan body function carry input and carry output must have equal types, but they differ:
The input carry component carry[0] has type float32[32,2] but the corresponding output carry component has type float32[32,2]{learner_devices}, so the varying manual axes do not match.
This might be fixed by applying `jax.lax.pvary(..., ('learner_devices',))` to the initial carry value corresponding to the input carry component carry[0].
Revise the function so that all output types match the corresponding input types.
Additional context
A potential solution is discussed here; we just need to implement it.
Possible Solution
Implement the solution discussed above. In the meantime, we can pin to jax==0.5.2.
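For reference, the `jax.lax.pvary` hint from the error message can be sketched on a toy example. This is not Mava's `calculate_gae`: the function `toy_gae`, the helper `mark_varying`, the one-device mesh and the input shapes are all assumptions made for illustration, and only the `learner_devices` axis name comes from the traceback. The idea is that an initial carry built from a fresh constant is not marked as varying over the mesh axis, while the carry produced by the scan body is, so the initial value is marked explicitly before the scan.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# Newer JAX releases expose jax.shard_map; older ones only ship the experimental module.
shard_map = getattr(jax, "shard_map", None)
if shard_map is None:
    from jax.experimental.shard_map import shard_map

AXIS = "learner_devices"  # axis name taken from the error message


def mark_varying(x):
    """Hypothetical helper: mark x as varying over AXIS when jax.lax.pvary is available."""
    pvary = getattr(jax.lax, "pvary", None)
    # Assumption: releases without pvary do not enforce the varying-axes check.
    return x if pvary is None else pvary(x, (AXIS,))


def toy_gae(rewards, values, last_value, gamma=0.99, lam=0.95):
    """Simplified GAE: rewards and values are [T, B], last_value is [B]."""

    def step(carry, inputs):
        advantage, next_value = carry
        reward, value = inputs
        delta = reward + gamma * next_value - value
        advantage = delta + gamma * lam * advantage
        return (advantage, value), advantage

    # A freshly created constant is not device-varying, unlike the scan output,
    # so it is marked explicitly to make the carry types match.
    init_advantage = mark_varying(jnp.zeros(last_value.shape, last_value.dtype))
    _, advantages = jax.lax.scan(
        step, (init_advantage, last_value), (rewards, values), reverse=True
    )
    return advantages


# One-device mesh so the sketch also runs on a CPU-only machine.
mesh = Mesh(np.array(jax.devices()[:1]), (AXIS,))
sharded_gae = jax.jit(
    shard_map(
        toy_gae,
        mesh=mesh,
        in_specs=(P(None, AXIS), P(None, AXIS), P(AXIS)),
        out_specs=P(None, AXIS),
    )
)

rewards = jnp.ones((3, 2), dtype=jnp.float32)
values = jnp.zeros((3, 2), dtype=jnp.float32)
last_value = jnp.zeros((2,), dtype=jnp.float32)
advantages = sharded_gae(rewards, values, last_value)
```

Note that `pvary` is only applied to a value that is not already varying over the axis; whether the same change is sufficient for the loss-calculation scans in Mava would still need to be checked against the actual code.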