14
14
15
15
"""Implement Jet Engine API."""
16
16
17
- from typing import Any , List , Optional , Tuple , Union
17
+ from typing import Any , List , Optional , Tuple , Union , Callable
18
18
import threading
19
19
import functools
20
20
import os
@@ -256,6 +256,7 @@ def prefill(
256
256
existing_prefix : Optional [Prefix ] = None ,
257
257
padded_tokens : PrefillInputs , # PrefillInputs[jax.Array],
258
258
true_length : int ,
259
+ sampler : Optional [Callable [[Any ], Any ]] = None ,
259
260
) -> Tuple [Prefix , engine_api .ResultTokens ]:
260
261
if isinstance (padded_tokens , jax .Array ):
261
262
batched_token = padded_tokens .reshape (1 , - 1 )
@@ -273,14 +274,17 @@ def prefill(
273
274
)
274
275
if len (logits .shape ) == 3 : # b, seqlen, num words
275
276
logits = logits [0 ] # seqlen, num words
276
- token = sampling_utils .sampling (
277
- logits [true_length - 1 ],
278
- self .rng ,
279
- self .env .sampling_algorithm ,
280
- self .env .topk ,
281
- self .env .nucleus_topp ,
282
- self .env .temperature ,
283
- )
277
+ if sampler :
278
+ token = sampler (logits [true_length - 1 ])
279
+ else :
280
+ token = sampling_utils .sampling (
281
+ logits [true_length - 1 ],
282
+ self .rng ,
283
+ self .env .sampling_algorithm ,
284
+ self .env .topk ,
285
+ self .env .nucleus_topp ,
286
+ self .env .temperature ,
287
+ )
284
288
token_out = jnp .reshape (token , (1 , 1 ))
285
289
data = jnp .concatenate (
286
290
[
@@ -610,7 +614,7 @@ def false_comp(b, i, bk, start, end):
610
614
return b_next , i_next
611
615
612
616
def generate (
613
- self , params : Any , decode_state : DecodeState
617
+ self , params : Any , decode_state : DecodeState , sampler = None
614
618
) -> tuple [DecodeState , engine_api .ResultTokens ]:
615
619
# seq_len = padded_tokens.shape[0]
616
620
pos = decode_state .current_position
@@ -653,7 +657,10 @@ def update_mask():
653
657
# fill mask later, now use flash attention
654
658
mask = update_mask ()
655
659
656
- next_token = self ._sampling (logits , self .env .batch_size )
660
+ if sampler :
661
+ next_token = sampler (logits [:, - 1 ])
662
+ else :
663
+ next_token = self ._sampling (logits , self .env .batch_size )
657
664
if self .env .ring_buffer :
658
665
input_pos = decode_state .input_pos + 1
659
666
lens = decode_state .lens + 1
0 commit comments