@@ -46,7 +46,8 @@ def main(argv):
46
46
47
47
if profiling_prefill :
48
48
jax .profiler .start_trace (profiling_output )
49
- decode_state = engine .init_decode_state ()
49
+ decode_state = engine .init_decode_state ()
50
+ if profiling_prefill :
50
51
jax .profiler .stop_trace ()
51
52
prompts : List [str ] = [
52
53
"I believe the meaning of life is" ,
@@ -65,11 +66,12 @@ def main(argv):
65
66
# pylint: disable-next=all
66
67
if profiling_prefill :
67
68
jax .profiler .start_trace (profiling_output )
68
- prefill_result , _ = engine .prefill (
69
- params = params , padded_tokens = tokens , true_length = true_length
70
- )
69
+ prefill_result , _ = engine .prefill (
70
+ params = params , padded_tokens = tokens , true_length = true_length
71
+ )
71
72
# pylint: disable-next=all
72
- decode_state = engine .insert (prefill_result , decode_state , slot = slot )
73
+ decode_state = engine .insert (prefill_result , decode_state , slot = slot )
74
+ if profiling_prefill :
73
75
jax .profiler .stop_trace ()
74
76
75
77
sampled_tokens_list = []
@@ -78,8 +80,10 @@ def main(argv):
78
80
while True :
79
81
if profiling_output :
80
82
jax .profiler .start_trace (profiling_output )
81
- decode_state , result_tokens = engine .generate (params , decode_state )
82
- result_tokens = result_tokens .convert_to_numpy ()
83
+ decode_state , result_tokens = engine .generate (params , decode_state )
84
+ result_tokens = result_tokens .convert_to_numpy ()
85
+
86
+ if profiling_output :
83
87
jax .profiler .stop_trace ()
84
88
output , complete = token_utils .process_result_tokens (
85
89
tokenizer = tokenizer ,
0 commit comments