Skip to content

Commit 0572532

Browse files
authored
Return np instead of jax array for prefill result tokens (#158)
return np instead of jax array for prefill result tokens
1 parent 60c2fa5 commit 0572532

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

jetstream_pt/ray_worker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -469,12 +469,12 @@ def prefill_ray(
469469
prefix = Prefix(token, updated_caches, true_length)
470470
self.prefix_queue.put(prefix, block=False)
471471

472-
token_out = jnp.reshape(token, (1, 1))
473-
data = jnp.concatenate(
472+
token_out = np.reshape(token, (1, 1))
473+
data = np.concatenate(
474474
[
475475
token_out, # First token
476-
jnp.ones_like(token_out), # validity of first token
477-
jnp.zeros((1, 1), dtype=jnp.int32), # length = 0
476+
np.ones_like(token_out), # validity of first token
477+
np.zeros((1, 1), dtype=np.int32), # length = 0
478478
],
479479
axis=-1,
480480
)

0 commit comments

Comments
 (0)