We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1e08833, commit 23fb3f2 (Copy full SHA for 23fb3f2)
jetstream_pt/ray_worker.py
@@ -466,9 +466,6 @@ def prefill_ray(
466
logits = logits[0]
467
468
token = np.argmax(logits[true_length - 1])
469
- updated_caches = multihost_utils.process_allgather(
470
- updated_caches, tiled=True
471
- )
472
prefix = Prefix(token, updated_caches, true_length)
473
self.prefix_queue.put(prefix, block=False)
474
@@ -490,7 +487,7 @@ def prefill_ray(
490
487
samples_per_slot=1,
491
488
)
492
489
493
- return prefix, result
+ return None, result
494
495
def _convert_to_np_caches(
496
self, caches: List[Tuple[jax.Array, jax.Array]]
0 commit comments