diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index 2b88055c..01b647db 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -469,12 +469,12 @@ def prefill_ray( prefix = Prefix(token, updated_caches, true_length) self.prefix_queue.put(prefix, block=False) - token_out = jnp.reshape(token, (1, 1)) - data = jnp.concatenate( + token_out = np.reshape(token, (1, 1)) + data = np.concatenate( [ token_out, # First token - jnp.ones_like(token_out), # validity of first token - jnp.zeros((1, 1), dtype=jnp.int32), # length = 0 + np.ones_like(token_out), # validity of first token + np.zeros((1, 1), dtype=np.int32), # length = 0 ], axis=-1, )