Skip to content

Commit ce28303

Browse files
Fix Ray engine crashes on multihost when fetching jax.Array from prefill_ray
1 parent 1e08833 commit ce28303

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

jetstream_pt/ray_worker.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -454,7 +454,7 @@ def prefill_ray(
454454
existing_prefix: Optional[Prefix] = None,
455455
padded_tokens: PrefillInputs, # PrefillInputs[np.ndarray],
456456
true_length: int,
457-
) -> tuple[Prefix, engine_api.ResultTokens]:
457+
) -> tuple[Prefix, engine_api.ResultTokens]:
458458
"""Do prefill in ray worker"""
459459
logits, updated_caches = self.prefill(
460460
params=params,
@@ -466,9 +466,6 @@ def prefill_ray(
466466
logits = logits[0]
467467

468468
token = np.argmax(logits[true_length - 1])
469-
updated_caches = multihost_utils.process_allgather(
470-
updated_caches, tiled=True
471-
)
472469
prefix = Prefix(token, updated_caches, true_length)
473470
self.prefix_queue.put(prefix, block=False)
474471

@@ -490,7 +487,7 @@ def prefill_ray(
490487
samples_per_slot=1,
491488
)
492489

493-
return prefix, result
490+
return None, result
494491

495492
def _convert_to_np_caches(
496493
self, caches: List[Tuple[jax.Array, jax.Array]]

0 commit comments

Comments
 (0)