This repository was archived by the owner on Oct 31, 2025. It is now read-only.
Commit 31022d6
[trax] Explicitly set
`trainer._multi_device_update_fn` uses `jax.pmap` and when `jax_pmap_shmap_merge=True`, `jax.pmap` requires inputs be explicitly sharded as the underlying `jax.jit` expects.
This would need to be fixed if `jax_pmap_shmap_merge=True`.
PiperOrigin-RevId: 811810947jax_pmap_shmap_merge=False.1 parent 3d4a276 commit 31022d6
1 file changed
+6
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
32 | 32 | | |
33 | 33 | | |
34 | 34 | | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
35 | 41 | | |
36 | 42 | | |
37 | 43 | | |
| |||
0 commit comments