diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index 1a4f15e..cbd0da5 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -466,9 +466,6 @@ def prefill_ray( logits = logits[0] token = np.argmax(logits[true_length - 1]) - updated_caches = multihost_utils.process_allgather( - updated_caches, tiled=True - ) prefix = Prefix(token, updated_caches, true_length) self.prefix_queue.put(prefix, block=False) @@ -490,7 +487,7 @@ def prefill_ray( samples_per_slot=1, ) - return prefix, result + return None, result def _convert_to_np_caches( self, caches: List[Tuple[jax.Array, jax.Array]]