diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index 9b0f6e4d..8a091b3f 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -70,7 +70,7 @@ def prefill( existing_prefix: Optional[Prefix] = None, padded_tokens: np.ndarray, # PrefillInputs[np.ndarray], true_length: int, - ) -> Prefix: + ) -> Tuple[Prefix, engine_api.ResultTokens]: if self.is_disaggregated: return self.prefill_impl( params=params, @@ -95,7 +95,7 @@ def prefill_impl( existing_prefix: Optional[Prefix] = None, padded_tokens: np.ndarray, # PrefillInputs[np.ndarray], true_length: int, - ) -> Prefix: + ) -> Tuple[Prefix, engine_api.ResultTokens]: all_outputs = [] for worker in self.engine_workers: prefill_func = ( diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index 25857f75..d0d8d836 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -461,7 +461,7 @@ def prefill_ray( existing_prefix: Optional[Prefix] = None, padded_tokens: PrefillInputs, # PrefillInputs[np.ndarray], true_length: int, - ) -> None: + ) -> tuple[Prefix, engine_api.ResultTokens]: """Do prefill in ray worker""" logits, updated_caches = self.prefill( params=params, @@ -476,7 +476,25 @@ def prefill_ray( prefix = Prefix(token, updated_caches, true_length) self.prefix_queue.put(prefix, block=False) - return token + token_out = jnp.reshape(token, (1, 1)) + data = jnp.concatenate( + [ + token_out, # First token + jnp.ones_like(token_out), # validity of first token + jnp.zeros((1, 1), dtype=jnp.int32), # length = 0 + ], + axis=-1, + ) + length = token_out.shape[1] + result = engine_api.ResultTokens( + data=data, + tokens_idx=(0, length), + valid_idx=(length, 2 * length), + length_idx=(2 * length, 2 * length + 1), + samples_per_slot=1, + ) + + return prefix, result def _convert_to_np_caches( self, caches: List[Tuple[jax.Array, jax.Array]] @@ -495,7 +513,7 @@ def prefill_ray_disaggregation( existing_prefix: Optional[Prefix] = None, padded_tokens: PrefillInputs, # PrefillInputs[np.ndarray], true_length: int, - ) -> Any: + ) -> tuple[NpPrefix, engine_api.ResultTokens]: """Do prefill in ray worker""" logits, updated_caches = self.prefill( params=params, @@ -513,7 +531,25 @@ def prefill_ray_disaggregation( np_update_caches = self._convert_to_np_caches(updated_caches) np_prefix = NpPrefix(token, np_update_caches, true_length) - return np_prefix + token_out = jnp.reshape(token, (1, 1)) + data = jnp.concatenate( + [ + token_out, # First token + jnp.ones_like(token_out), # validity of first token + jnp.zeros((1, 1), dtype=jnp.int32), # length = 0 + ], + axis=-1, + ) + length = token_out.shape[1] + result = engine_api.ResultTokens( + data=data, + tokens_idx=(0, length), + valid_idx=(length, 2 * length), + length_idx=(2 * length, 2 * length + 1), + samples_per_slot=1, + ) + + return np_prefix, result def transfer(self, np_prefix: NpPrefix) -> Any: """Transfer prefill result from object store to HBM"""