Skip to content

Commit eaf393c

Browse files
committed
Update Jetstream, add optional sampler args.
1 parent 7cbd9ec commit eaf393c

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

deps/JetStream

jetstream_pt/engine.py

Lines changed: 18 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414

1515
"""Implement Jet Engine API."""
1616

17-
from typing import Any, List, Optional, Tuple, Union
17+
from typing import Any, List, Optional, Tuple, Union, Callable
1818
import threading
1919
import functools
2020
import os
@@ -256,6 +256,7 @@ def prefill(
256256
existing_prefix: Optional[Prefix] = None,
257257
padded_tokens: PrefillInputs, # PrefillInputs[jax.Array],
258258
true_length: int,
259+
sampler: Optional[Callable[[Any], Any]] = None,
259260
) -> Tuple[Prefix, engine_api.ResultTokens]:
260261
if isinstance(padded_tokens, jax.Array):
261262
batched_token = padded_tokens.reshape(1, -1)
@@ -273,14 +274,17 @@ def prefill(
273274
)
274275
if len(logits.shape) == 3: # b, seqlen, num words
275276
logits = logits[0] # seqlen, num words
276-
token = sampling_utils.sampling(
277-
logits[true_length - 1],
278-
self.rng,
279-
self.env.sampling_algorithm,
280-
self.env.topk,
281-
self.env.nucleus_topp,
282-
self.env.temperature,
283-
)
277+
if sampler:
278+
token = sampler(logits[true_length - 1])
279+
else:
280+
token = sampling_utils.sampling(
281+
logits[true_length - 1],
282+
self.rng,
283+
self.env.sampling_algorithm,
284+
self.env.topk,
285+
self.env.nucleus_topp,
286+
self.env.temperature,
287+
)
284288
token_out = jnp.reshape(token, (1, 1))
285289
data = jnp.concatenate(
286290
[
@@ -610,7 +614,7 @@ def false_comp(b, i, bk, start, end):
610614
return b_next, i_next
611615

612616
def generate(
613-
self, params: Any, decode_state: DecodeState
617+
self, params: Any, decode_state: DecodeState, sampler = None
614618
) -> tuple[DecodeState, engine_api.ResultTokens]:
615619
# seq_len = padded_tokens.shape[0]
616620
pos = decode_state.current_position
@@ -653,7 +657,10 @@ def update_mask():
653657
# fill mask later, now use flash attention
654658
mask = update_mask()
655659

656-
next_token = self._sampling(logits, self.env.batch_size)
660+
if sampler:
661+
next_token = sampler(logits[:, -1])
662+
else:
663+
next_token = self._sampling(logits, self.env.batch_size)
657664
if self.env.ring_buffer:
658665
input_pos = decode_state.input_pos + 1
659666
lens = decode_state.lens + 1

0 commit comments

Comments
 (0)