Commit 62f3c51

richardsliu authored and wang2yn84 committed
Fix Ray engine crash on multihost (#164)
1 parent d14e7f5 commit 62f3c51

1 file changed, +3 -0 lines changed

jetstream_pt/ray_worker.py

Lines changed: 3 additions & 0 deletions
@@ -466,6 +466,9 @@ def prefill_ray(
     logits = logits[0]

     token = np.argmax(logits[true_length - 1])
+    updated_caches = multihost_utils.process_allgather(
+        updated_caches, tiled=True
+    )
     prefix = Prefix(token, updated_caches, true_length)
     self.prefix_queue.put(prefix, block=False)

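The added lines pass `updated_caches` through `multihost_utils.process_allgather` with `tiled=True` before the `Prefix` is built. As a rough illustration of the difference between the tiled and non-tiled layouts, here is a pure-Python simulation. It does not call JAX, and `simulated_allgather` and `host_shards` are hypothetical names invented for this sketch. It assumes the behavior described in the JAX documentation: `tiled=True` concatenates per-host pieces along the existing leading axis, while the default stacks them on a new leading axis.

```python
# Pure-Python simulation of the two gather layouts; this does NOT call JAX.
# Each "host" holds one shard of a cache as a list of rows (hypothetical data).


def simulated_allgather(shards, tiled):
    """Combine per-host shards into one value visible to every host.

    tiled=True  -> concatenate along the existing leading axis.
    tiled=False -> stack, adding a new leading axis of size len(shards).
    """
    if tiled:
        return [row for shard in shards for row in shard]
    return [list(shard) for shard in shards]


host_shards = [
    [[1, 2], [3, 4]],  # rows held by host 0
    [[5, 6], [7, 8]],  # rows held by host 1
]

tiled_result = simulated_allgather(host_shards, tiled=True)
stacked_result = simulated_allgather(host_shards, tiled=False)

print(len(tiled_result))    # number of rows after concatenation
print(len(stacked_result))  # number of shards after stacking
```

In the tiled layout the result keeps the same number of dimensions as each shard, which is presumably why the commit uses it for the caches handed to `Prefix`; the commit itself does not state the reason.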