File tree 1 file changed +2
-5
lines changed
1 file changed +2
-5
lines changed Original file line number Diff line number Diff line change @@ -454,7 +454,7 @@ def prefill_ray(
454
454
existing_prefix : Optional [Prefix ] = None ,
455
455
padded_tokens : PrefillInputs , # PrefillInputs[np.ndarray],
456
456
true_length : int ,
457
- ) -> tuple [Prefix , engine_api .ResultTokens ]:
457
+ ) -> tuple [Prefix , engine_api .ResultTokens ]:
458
458
"""Do prefill in ray worker"""
459
459
logits , updated_caches = self .prefill (
460
460
params = params ,
@@ -466,9 +466,6 @@ def prefill_ray(
466
466
logits = logits [0 ]
467
467
468
468
token = np .argmax (logits [true_length - 1 ])
469
- updated_caches = multihost_utils .process_allgather (
470
- updated_caches , tiled = True
471
- )
472
469
prefix = Prefix (token , updated_caches , true_length )
473
470
self .prefix_queue .put (prefix , block = False )
474
471
@@ -490,7 +487,7 @@ def prefill_ray(
490
487
samples_per_slot = 1 ,
491
488
)
492
489
493
- return prefix , result
490
+ return None , result
494
491
495
492
def _convert_to_np_caches (
496
493
self , caches : List [Tuple [jax .Array , jax .Array ]]
You can’t perform that action at this time.
0 commit comments