File tree 1 file changed +2
-3
lines changed
1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -357,7 +357,7 @@ def _call_model_generate(
357
357
)
358
358
for k , v in torchjax .to_torch (caches )
359
359
]
360
- mask = jnp .expand_dims (mask , (1 , 2 ))
360
+ mask = jnp .expand_dims (new_mask , (1 , 2 ))
361
361
362
362
args = (tokens , input_pos , caches_obj , mask )
363
363
paramst , argst = torchjax .to_torch ((weights , args ))
@@ -371,7 +371,6 @@ def _call_model_generate(
371
371
new_current_position = (
372
372
current_position + 1
373
373
) % self .env .cache_sequence_length
374
-
375
374
return torchjax .from_torch (
376
375
(
377
376
res ,
@@ -816,7 +815,7 @@ def generate(
816
815
length_idx = (2 * length , 2 * length + 1 ),
817
816
samples_per_slot = 1 ,
818
817
)
819
-
818
+ next_token = jax . lax . with_sharding_constraint ( next_token , self . replicated )
820
819
new_decode_state = DecodeState (
821
820
next_token ,
822
821
new_caches ,
You can’t perform that action at this time.
0 commit comments