Skip to content

Commit 23fb3f2

Browse files
fix Ray engine crashes on multihost when fetching Jax.array from prefill_ray
1 parent 1e08833 commit 23fb3f2

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

jetstream_pt/ray_worker.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -466,9 +466,6 @@ def prefill_ray(
466466
logits = logits[0]
467467

468468
token = np.argmax(logits[true_length - 1])
469-
updated_caches = multihost_utils.process_allgather(
470-
updated_caches, tiled=True
471-
)
472469
prefix = Prefix(token, updated_caches, true_length)
473470
self.prefix_queue.put(prefix, block=False)
474471

@@ -490,7 +487,7 @@ def prefill_ray(
490487
samples_per_slot=1,
491488
)
492489

493-
return prefix, result
490+
return None, result
494491

495492
def _convert_to_np_caches(
496493
self, caches: List[Tuple[jax.Array, jax.Array]]

0 commit comments

Comments
 (0)