Skip to content

Commit d84ae15

Browse files
Fix ray recompilation and accuracy (#189)
1 parent f2e5181 commit d84ae15

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

jetstream_pt/ray_worker.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def _call_model_generate(
357357
)
358358
for k, v in torchjax.to_torch(caches)
359359
]
360-
mask = jnp.expand_dims(mask, (1, 2))
360+
mask = jnp.expand_dims(new_mask, (1, 2))
361361

362362
args = (tokens, input_pos, caches_obj, mask)
363363
paramst, argst = torchjax.to_torch((weights, args))
@@ -371,7 +371,6 @@ def _call_model_generate(
371371
new_current_position = (
372372
current_position + 1
373373
) % self.env.cache_sequence_length
374-
375374
return torchjax.from_torch(
376375
(
377376
res,
@@ -816,7 +815,7 @@ def generate(
816815
length_idx=(2 * length, 2 * length + 1),
817816
samples_per_slot=1,
818817
)
819-
818+
next_token = jax.lax.with_sharding_constraint(next_token, self.replicated)
820819
new_decode_state = DecodeState(
821820
next_token,
822821
new_caches,

0 commit comments

Comments
 (0)