From 23fb3f2b3d30af07d0496ea950afdaca82620bd1 Mon Sep 17 00:00:00 2001 From: Xiang Si Date: Mon, 19 Aug 2024 18:30:17 +0000 Subject: [PATCH] fix Ray engine crashes on multihost when fetching Jax.array from prefill_ray --- jetstream_pt/ray_worker.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index 1a4f15e..cbd0da5 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -466,9 +466,6 @@ def prefill_ray( logits = logits[0] token = np.argmax(logits[true_length - 1]) - updated_caches = multihost_utils.process_allgather( - updated_caches, tiled=True - ) prefix = Prefix(token, updated_caches, true_length) self.prefix_queue.put(prefix, block=False) @@ -490,7 +487,7 @@ def prefill_ray( samples_per_slot=1, ) - return prefix, result + return None, result def _convert_to_np_caches( self, caches: List[Tuple[jax.Array, jax.Array]]