Skip to content

Commit 5855b3d

Browse files
committed
Fix the interactive script.
1 parent 2dffb49 commit 5855b3d

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

run_interactive.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def main(argv):
46 46
47 47   if profiling_prefill:
48 48   jax.profiler.start_trace(profiling_output)
49    - decode_state = engine.init_decode_state()
   49 + decode_state = engine.init_decode_state()
   50 + if profiling_prefill:
50 51   jax.profiler.stop_trace()
51 52   prompts: List[str] = [
52 53   "I believe the meaning of life is",
@@ -65,11 +66,12 @@ def main(argv):
65 66   # pylint: disable-next=all
66 67   if profiling_prefill:
67 68   jax.profiler.start_trace(profiling_output)
68    - prefill_result, _ = engine.prefill(
69    - params=params, padded_tokens=tokens, true_length=true_length
70    - )
   69 + prefill_result, _ = engine.prefill(
   70 + params=params, padded_tokens=tokens, true_length=true_length
   71 + )
71 72   # pylint: disable-next=all
72    - decode_state = engine.insert(prefill_result, decode_state, slot=slot)
   73 + decode_state = engine.insert(prefill_result, decode_state, slot=slot)
   74 + if profiling_prefill:
73 75   jax.profiler.stop_trace()
74 76
75 77   sampled_tokens_list = []
@@ -78,8 +80,10 @@ def main(argv):
78 80   while True:
79 81   if profiling_output:
80 82   jax.profiler.start_trace(profiling_output)
81    - decode_state, result_tokens = engine.generate(params, decode_state)
82    - result_tokens = result_tokens.convert_to_numpy()
   83 + decode_state, result_tokens = engine.generate(params, decode_state)
   84 + result_tokens = result_tokens.convert_to_numpy()
   85 +
   86 + if profiling_output:
83 87   jax.profiler.stop_trace()
84 88   output, complete = token_utils.process_result_tokens(
85 89   tokenizer=tokenizer,

0 commit comments

Comments
 (0)