27
27
# Command-line flags. Each call registers a flag name, its default value and
# the help text shown to the user.
flags.DEFINE_integer("max_output_length", 1024, "The maximum output length")
flags.DEFINE_integer("port", 9000, "port to listen on")
flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
flags.DEFINE_string(
    "benchmark_save_offline_result_to_file",
    "",
    "if set, then save the result to the given file name",
)
31
35
32
36
33
37
def shard_weights (env , weights , weight_shardings ):
@@ -115,21 +119,24 @@ def _check_model_id():
115
119
list_model ()
116
120
sys .exit (1 )
117
121
118
- def _run_prefill_time (engine , params , decode_state , seqlen , profiler_started ):
122
+
123
+ def _run_prefill_time (
124
+ pt_engine , params , decode_state , seqlen , profiler_started
125
+ ):
119
126
"""Run prefill and measure time."""
120
- metadata = engine .get_tokenizer ()
121
- tokenizer = engine .build_tokenizer (metadata )
127
+ metadata = pt_engine .get_tokenizer ()
128
+ tokenizer = pt_engine .build_tokenizer (metadata )
122
129
123
130
text = "This is a beautiful day"
124
131
tokens , true_length = tokenizer .encode (
125
132
text , is_bos = True , prefill_lengths = [seqlen ]
126
133
)
127
134
128
135
for _ in range (3 ):
129
- prefill_result , _ = engine .prefill (
136
+ prefill_result , _ = pt_engine .prefill (
130
137
params = params , padded_tokens = tokens , true_length = true_length
131
138
)
132
- decode_state = engine .insert (
139
+ decode_state = pt_engine .insert (
133
140
prefill_result , decode_state , slot = jnp .int32 (1 )
134
141
)
135
142
@@ -140,10 +147,10 @@ def _run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
140
147
jax .profiler .start_trace (FLAGS .profiling_output )
141
148
profiler_started = True
142
149
143
- prefill_result , _ = engine .prefill (
150
+ prefill_result , _ = pt_engine .prefill (
144
151
params = params , padded_tokens = tokens , true_length = true_length
145
152
)
146
- decode_state = engine .insert (
153
+ decode_state = pt_engine .insert (
147
154
prefill_result , decode_state , slot = jnp .int32 (i )
148
155
)
149
156
jax .block_until_ready (decode_state )
@@ -244,25 +251,25 @@ def interactive():
244
251
print ("---- All output text." )
245
252
print (tokenizer .decode (sampled_tokens_list ))
246
253
254
+
247
255
def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
    """Write the offline benchmark timings to ``filename`` as a text report.

    The report uses Markdown-style headings and a pipe-delimited table. It
    starts with a title, then the model id, batch size and quantization
    setting read from ``FLAGS``, followed by one table row per prefill entry
    and a final decode row.

    Args:
      filename: path of the file to create or overwrite.
      prefill_times_ms: mapping; each ``(key, value)`` item becomes a
        ``Prefill <key>`` row with ``value`` in the time column.
      decode_time_ms: value placed in the ``Decode`` row.
    """
    report = [
        " # Offline benchmark numbers",
        # Plain concatenation is kept on purpose: it requires model_id to be
        # a string, exactly as before.
        " ## Model: " + FLAGS.model_id,
        f" ## Batch size: {FLAGS.override_batch_size}",
        f" ## Quantize: {FLAGS.quantize_weights}",
        " | | time (ms) |",
        " |-------|-----------|",
    ]
    for label, elapsed in prefill_times_ms.items():
        report.append(f"| Prefill {label} | {elapsed} |")
    report.append(f"| Decode | {decode_time_ms} |")

    with open(filename, "w", encoding="utf-8") as out:
        out.write("\n".join(report))
        out.flush()
263
271
264
272
265
-
266
273
def benchmark_offline ():
267
274
"""function to run engine offline."""
268
275
_check_model_id ()
@@ -280,7 +287,7 @@ def benchmark_offline():
280
287
profiler_started = False
281
288
# 16 .. 1024
282
289
for exp in range (4 , 11 ):
283
- batch = 2 ** exp
290
+ batch = 2 ** exp
284
291
runtime , decode_state , profiler_started = _run_prefill_time (
285
292
pt_engine , params , decode_state , batch , profiler_started
286
293
)
@@ -333,13 +340,12 @@ def benchmark_offline():
333
340
334
341
if FLAGS .benchmark_save_offline_result_to_file :
335
342
_save_benchmark_to_file (
336
- FLAGS .benchmark_save_offline_result_to_file ,
337
- prefill_times_ms ,
338
- decode_time_ms
343
+ FLAGS .benchmark_save_offline_result_to_file ,
344
+ prefill_times_ms ,
345
+ decode_time_ms ,
339
346
)
340
347
341
348
342
-
343
349
def main ():
344
350
"""Main function."""
345
351
0 commit comments