Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.

Commit 31022d6

Browse files
danielsuocopybara-github
authored andcommitted
[trax] Explicitly set jax_pmap_shmap_merge=False.
`trainer._multi_device_update_fn` uses `jax.pmap` and when `jax_pmap_shmap_merge=True`, `jax.pmap` requires inputs be explicitly sharded as the underlying `jax.jit` expects. This would need to be fixed if `jax_pmap_shmap_merge=True`. PiperOrigin-RevId: 811810947
1 parent 3d4a276 commit 31022d6

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

trax/optimizers/trainer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@
3232
from trax.layers import combinators as cb
3333

3434

35+
# NOTE(dsuo): _multi_device_update_fn is not compatible with
36+
# jax_pmap_shmap_merge=True because `jax.pmap` requires inputs to be explicitly
37+
# sharded as the underlying `jax.jit` expects.
38+
jax.config.update('jax_pmap_shmap_merge', False)
39+
40+
3541
class Trainer:
3642
"""Multi-device accelerated trainer.
3743

0 commit comments

Comments
 (0)