# import torch_xla2 first!
import torch_xla2  # pylint: disable
import jax
+ from jax import numpy as jnp
from absl import app, flags
from jetstream.engine import token_utils
from jetstream.core import server_lib
flags.DEFINE_integer("max_output_length", 1024, "Maximum output length")
flags.DEFINE_integer("port", 9000, "port to listen on")
flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
+ flags.DEFINE_string(
+     "benchmark_save_offline_result_to_file",
+     "",
+     "if set, then save the result to the given file name",
+ )
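+ # Example invocation (illustrative only; the actual entry point and flag values
+ # depend on how this CLI is installed):
+ #   <cli> benchmark_offline --model_id=<model_id> \
+ #       --benchmark_save_offline_result_to_file=/tmp/offline_benchmark.md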


def shard_weights(env, weights, weight_shardings):
@@ -113,6 +115,42 @@ def _check_model_id():
    list_model()
    sys.exit(1)

+ def _run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
+   """Run prefill and measure time."""
+   metadata = engine.get_tokenizer()
+   tokenizer = engine.build_tokenizer(metadata)
+
+   text = "This is a beautiful day"
+   tokens, true_length = tokenizer.encode(
+       text, is_bos=True, prefill_lengths=[seqlen]
+   )
+
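+   # Warm-up: run a few untimed prefill + insert iterations first so one-time
+   # compilation/tracing cost is not counted in the measurement.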
+   for _ in range(3):
+     prefill_result, _ = engine.prefill(
+         params=params, padded_tokens=tokens, true_length=true_length
+     )
+     decode_state = engine.insert(
+         prefill_result, decode_state, slot=jnp.int32(1)
+     )
+
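+   # Timed runs: average over `nums` prefill calls. The profiler trace (if
+   # requested via FLAGS.profiling_prefill) is started on the last iteration,
+   # and block_until_ready() ensures the asynchronously dispatched device work
+   # has finished before the timer stops.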
+   nums = 5
+   start = time.perf_counter()
+   for i in range(nums):
+     if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
+       jax.profiler.start_trace(FLAGS.profiling_output)
+       profiler_started = True
+
+     prefill_result, _ = engine.prefill(
+         params=params, padded_tokens=tokens, true_length=true_length
+     )
+     decode_state = engine.insert(
+         prefill_result, decode_state, slot=jnp.int32(i)
+     )
+     jax.block_until_ready(decode_state)
+
+   end = time.perf_counter()
+   return (end - start) / nums, decode_state, profiler_started
+


def interactive():
  """Run interactive"""
@@ -206,6 +244,101 @@ def interactive():
    print("---- All output text.")
    print(tokenizer.decode(sampled_tokens_list))

+ def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
+   """Save the benchmark numbers as a small markdown table."""
+   lines = [
+       " # Offline benchmark numbers",
+       " ## Model: " + FLAGS.model_id,
+       " ## Batch size: {}".format(FLAGS.override_batch_size),
+       " ## Quantize: {}".format(FLAGS.quantize_weights),
+       " | | time (ms) |",
+       " |-------|-----------|",
+   ] + [
+       "| Prefill {} | {} |".format(x, y) for x, y in prefill_times_ms.items()
+   ] + [
+       "| Decode | {} |".format(decode_time_ms)
+   ]
+   with open(filename, "w") as f:
+     f.write("\n".join(lines))
+     f.flush()
+
+
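+ # For illustration, the saved file is a small markdown report along these
+ # lines (model name and timings are made-up placeholders):
+ #    # Offline benchmark numbers
+ #    ## Model: <model_id>
+ #   | Prefill 1024 | 120.5 |
+ #   | Decode | 40.2 |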
+ def benchmark_offline():
+   """Run the engine offline and print prefill/decode benchmark numbers."""
+   _check_model_id()
+   devices = server_lib.get_devices()
+   print(f"devices: {devices}")
+   pt_engine = create_engine(devices)
+
+   start = time.perf_counter()
+   params = pt_engine.load_params()
+   print("Load params ", time.perf_counter() - start)
+
+   prefill_times = {}
+
+   decode_state = pt_engine.init_decode_state()
+   profiler_started = False
+   # Sweep prefill lengths 16, 32, ..., 1024 (2**4 .. 2**10); note that `batch`
+   # here is the padded prefill sequence length passed as `seqlen`.
+   for exp in range(4, 11):
+     batch = 2**exp
+     runtime, decode_state, profiler_started = _run_prefill_time(
+         pt_engine, params, decode_state, batch, profiler_started
+     )
+     prefill_times[batch] = runtime
+
+   sampled_tokens_list = []
+
+   for i in range(3):  # warm up
+     # pylint: disable-next=all
+     decode_state, sampled_tokens = pt_engine.generate(
+         params=params, decode_state=decode_state
+     )
+     sampled_tokens_list.append(sampled_tokens)
+
+   profiling_output = FLAGS.profiling_output
+   print("======= decode starting ===")
+
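+   # Time 10 decode steps. The first two are dropped as warm-up when averaging
+   # below; if profiling is enabled, the trace starts at step 7 so it captures a
+   # few steady-state steps.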
+   dec_times = []
+   for i in range(10):
+     if profiling_output and i == 7 and not profiler_started:
+       jax.profiler.start_trace(profiling_output)
+       profiler_started = True
+     start = time.perf_counter()
+     # pylint: disable-next=all
+     decode_state, sampled_tokens = pt_engine.generate(params, decode_state)
+     jax.block_until_ready(decode_state)
+     sampled_tokens_list.append(sampled_tokens)
+     end = time.perf_counter()
+     dec_times.append(end - start)
+     print(i, "decode time", (end - start))
+
+   if profiler_started:
+     jax.profiler.stop_trace()
+
+   print("prefill ", prefill_times)
+   avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
+   print("decode", avg_decode_times)
+
+   prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
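+   # Average per-step decode latency in ms over the 8 timed iterations
+   # (10 runs minus the 2 warm-up steps dropped above).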
+   decode_time_ms = sum(dec_times[2:]) * 1000 / 8
+
+   largest_prefill = max(prefill_times.items())
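+   # Two rough throughput estimates: decode-only tokens/sec (one token per batch
+   # slot per decode step), and an end-to-end figure that also charges the cost
+   # of prefilling every slot at the largest measured prefill length.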
+   print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
+
+   time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
+       FLAGS.batch_size * largest_prefill[1]
+       + FLAGS.max_decode_length * avg_decode_times
+   )
+   print("MAX tokens 2:", time2)
+
+   if FLAGS.benchmark_save_offline_result_to_file:
+     _save_benchmark_to_file(
+         FLAGS.benchmark_save_offline_result_to_file,
+         prefill_times_ms,
+         decode_time_ms,
+     )
+
+


def main():
  """Main function."""
@@ -221,6 +354,8 @@ def main_real(argv):
    serve()
  elif argv[1] == "interactive":
    interactive()
+   elif argv[1] == "benchmark_offline":
+     benchmark_offline()
  else:
    print(
        "Invalid arguments. please specify 'list', 'serve', or 'interactive'."