5
5
# import torch_xla2 first!
6
6
import torch_xla2 # pylint: disable
7
7
import jax
8
+ from jax import numpy as jnp
8
9
from absl import app , flags
9
10
from jetstream .engine import token_utils
10
11
from jetstream .core import server_lib
26
27
# Command-line flags shared by the serving and benchmarking entry points.
flags.DEFINE_integer("max_output_length", 1024, "The batch size")
flags.DEFINE_integer("port", 9000, "port to listen on")
flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
flags.DEFINE_string(
    "benchmark_save_offline_result_to_file",
    "",
    "if set, then save the result to the given file name",
)
29
35
30
36
31
37
def shard_weights (env , weights , weight_shardings ):
@@ -114,6 +120,45 @@ def _check_model_id():
114
120
sys .exit (1 )
115
121
116
122
123
def _run_prefill_time(
    pt_engine, params, decode_state, seqlen, profiler_started
):
  """Measure the average prefill+insert latency at a given padded length.

  Encodes a fixed prompt padded to `seqlen`, runs three untimed warm-up
  rounds (to trigger compilation), then times five prefill+insert rounds.

  Returns:
    A tuple of (average seconds per round, updated decode_state,
    profiler_started flag).
  """
  metadata = pt_engine.get_tokenizer()
  tokenizer = pt_engine.build_tokenizer(metadata)

  prompt = "This is a beautiful day"
  tokens, true_length = tokenizer.encode(
      prompt, is_bos=True, prefill_lengths=[seqlen]
  )

  # Warm up: compile and cache the prefill / insert programs.
  for _ in range(3):
    prefill_result, _ = pt_engine.prefill(
        params=params, padded_tokens=tokens, true_length=true_length
    )
    decode_state = pt_engine.insert(
        prefill_result, decode_state, slot=jnp.int32(1)
    )

  num_iters = 5
  start = time.perf_counter()
  for step in range(num_iters):
    # Capture a profiler trace on the final timed iteration, if requested
    # and no trace is already running.
    if (
        step == num_iters - 1
        and FLAGS.profiling_prefill
        and not profiler_started
    ):
      jax.profiler.start_trace(FLAGS.profiling_output)
      profiler_started = True

    prefill_result, _ = pt_engine.prefill(
        params=params, padded_tokens=tokens, true_length=true_length
    )
    decode_state = pt_engine.insert(
        prefill_result, decode_state, slot=jnp.int32(step)
    )
    # Force completion of the async dispatch so the timing is honest.
    jax.block_until_ready(decode_state)

  elapsed = time.perf_counter() - start
  return elapsed / num_iters, decode_state, profiler_started
160
+
161
+
117
162
def interactive ():
118
163
"""Run interactive"""
119
164
_check_model_id ()
@@ -207,6 +252,100 @@ def interactive():
207
252
print (tokenizer .decode (sampled_tokens_list ))
208
253
209
254
255
def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
  """Write offline benchmark results to `filename` as a markdown table."""
  header = [
      " # Offline benchmark numbers",
      " ## Model: " + FLAGS.model_id,
      f" ## Batch size: {FLAGS.override_batch_size}",
      f" ## Quantize: {FLAGS.quantize_weights}",
      " |       | time (ms) |",
      " |-------|-----------|",
  ]
  prefill_rows = [
      f"| Prefill {x} | {y} |" for x, y in prefill_times_ms.items()
  ]
  decode_row = [f"| Decode | {decode_time_ms} |"]
  report = header + prefill_rows + decode_row
  with open(filename, "w", encoding="utf-8") as f:
    f.write("\n".join(report))
    f.flush()
271
+
272
+
273
def benchmark_offline():
  """function to run engine offline."""
  _check_model_id()
  devices = server_lib.get_devices()
  print(f"devices: {devices}")
  pt_engine = create_engine(devices)

  load_start = time.perf_counter()
  params = pt_engine.load_params()
  print("Load params ", time.perf_counter() - load_start)

  prefill_times = {}

  decode_state = pt_engine.init_decode_state()
  profiler_started = False
  # Sweep prefill lengths 16 .. 1024 (powers of two).
  for power in range(4, 11):
    length = 2**power
    avg_runtime, decode_state, profiler_started = _run_prefill_time(
        pt_engine, params, decode_state, length, profiler_started
    )
    prefill_times[length] = avg_runtime

  sampled_tokens_list = []

  # Warm up the decode step before timing it.
  for _ in range(3):
    # pylint: disable-next=all
    decode_state, sampled_tokens = pt_engine.generate(
        params=params, decode_state=decode_state
    )
    sampled_tokens_list.append(sampled_tokens)

  profiling_output = FLAGS.profiling_output
  print("======= decode starting ===")

  dec_times = []
  for step in range(10):
    # Start a trace near the end of the run so it captures steady-state
    # decode steps, unless a trace is already active.
    if profiling_output and step == 7 and not profiler_started:
      jax.profiler.start_trace(profiling_output)
      profiler_started = True
    step_start = time.perf_counter()
    # pylint: disable-next=all
    decode_state, sampled_tokens = pt_engine.generate(params, decode_state)
    jax.block_until_ready(decode_state)
    sampled_tokens_list.append(sampled_tokens)
    step_time = time.perf_counter() - step_start
    dec_times.append(step_time)
    print(step, "decode time", (step_time))

  if profiler_started:
    jax.profiler.stop_trace()

  print("prefill ", prefill_times)
  # Drop the first two timed decode steps (still warming up) from the average.
  avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
  print("decode", avg_decode_times)

  prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
  decode_time_ms = sum(dec_times[2:]) * 1000 / 8

  largest_prefill = max(prefill_times.items())
  print("MAX tokens:", FLAGS.batch_size / avg_decode_times)

  # Throughput estimate combining the largest prefill and decode costs.
  time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
      FLAGS.batch_size * largest_prefill[1]
      + FLAGS.max_decode_length * avg_decode_times
  )
  print("MAX tokens 2:", time2)

  if FLAGS.benchmark_save_offline_result_to_file:
    _save_benchmark_to_file(
        FLAGS.benchmark_save_offline_result_to_file,
        prefill_times_ms,
        decode_time_ms,
    )
347
+
348
+
210
349
def main ():
211
350
"""Main function."""
212
351
@@ -221,6 +360,8 @@ def main_real(argv):
221
360
serve ()
222
361
elif argv [1 ] == "interactive" :
223
362
interactive ()
363
+ elif argv [1 ] == "benchmark_offline" :
364
+ benchmark_offline ()
224
365
else :
225
366
print (
226
367
"Invalid arguments. please specify 'list', 'serve', or 'interactive'."
0 commit comments