# import torch_xla2 first!
import torch_xla2  # pylint: disable
import jax
+ from jax import numpy as jnp
from absl import app, flags
from jetstream.engine import token_utils
from jetstream.core import server_lib
@@ -26,6 +27,7 @@
flags.DEFINE_integer("max_output_length", 1024, "Maximum output length")
flags.DEFINE_integer("port", 9000, "port to listen on")
flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
+ flags.DEFINE_string("benchmark_save_offline_result_to_file", "", "if set, then save the result to the given file name")


def shard_weights(env, weights, weight_shardings):
@@ -113,6 +115,42 @@ def _check_model_id():
    list_model()
    sys.exit(1)

+ def _run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
+   """Run prefill and measure time."""
+   metadata = engine.get_tokenizer()
+   tokenizer = engine.build_tokenizer(metadata)
+ 
+   text = "This is a beautiful day"
+   tokens, true_length = tokenizer.encode(
+       text, is_bos=True, prefill_lengths=[seqlen]
+   )
+ 
+   # Warm up the prefill and insert paths before timing.
+   for _ in range(3):
+     prefill_result, _ = engine.prefill(
+         params=params, padded_tokens=tokens, true_length=true_length
+     )
+     decode_state = engine.insert(
+         prefill_result, decode_state, slot=jnp.int32(1)
+     )
+ 
+   # Time `nums` prefill + insert iterations and report the average.
+   nums = 5
+   start = time.perf_counter()
+   for i in range(nums):
+     if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
+       jax.profiler.start_trace(FLAGS.profiling_output)
+       profiler_started = True
+ 
+     prefill_result, _ = engine.prefill(
+         params=params, padded_tokens=tokens, true_length=true_length
+     )
+     decode_state = engine.insert(
+         prefill_result, decode_state, slot=jnp.int32(i)
+     )
+   jax.block_until_ready(decode_state)
+ 
+   end = time.perf_counter()
+   return (end - start) / nums, decode_state, profiler_started
+ 


def interactive():
  """Run interactive"""
@@ -206,6 +244,99 @@ def interactive():
    print("---- All output text.")
    print(tokenizer.decode(sampled_tokens_list))

+ def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
+   """Write prefill and decode timings to `filename` as a markdown table."""
+   lines = [
+       " # Offline benchmark numbers",
+       " ## Command: " + " ".join(sys.argv),
+       " | | time (ms) |",
+       " |-------|-----------|",
+   ] + [
+       "| Prefill {} | {} |".format(x, y) for x, y in prefill_times_ms.items()
+   ] + [
+       "| Decode | {} |".format(decode_time_ms)
+   ]
+   with open(filename, "w") as f:
+     f.write("\n".join(lines))
+     f.flush()
+ 
+
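For reference, the file that _save_benchmark_to_file writes is a small markdown report assembled from the format strings above. With hypothetical timings (10 ms for the length-16 prefill, 5 ms per decode step) it would look roughly like:

     # Offline benchmark numbers
     ## Command: <the benchmark command line>
     | | time (ms) |
     |-------|-----------|
    | Prefill 16 | 10.0 |
    | Decode | 5.0 |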
+ def benchmark_offline():
+   """Run the engine offline and measure prefill and decode step times."""
+   _check_model_id()
+   devices = server_lib.get_devices()
+   print(f"devices: {devices}")
+   pt_engine = create_engine(devices)
+ 
+   start = time.perf_counter()
+   params = pt_engine.load_params()
+   print("Load params ", time.perf_counter() - start)
+ 
+   prefill_times = {}
+ 
+   decode_state = pt_engine.init_decode_state()
+   profiler_started = False
+   # Measure prefill time for padded lengths 16 .. 1024.
+   for exp in range(4, 11):
+     batch = 2**exp
+     runtime, decode_state, profiler_started = _run_prefill_time(
+         pt_engine, params, decode_state, batch, profiler_started
+     )
+     prefill_times[batch] = runtime
+ 
+   sampled_tokens_list = []
+ 
+   for i in range(3):  # warm up
+     # pylint: disable-next=all
+     decode_state, sampled_tokens = pt_engine.generate(
+         params=params, decode_state=decode_state
+     )
+     sampled_tokens_list.append(sampled_tokens)
+ 
+   profiling_output = FLAGS.profiling_output
+   print("======= decode starting ===")
+ 
+   dec_times = []
+   for i in range(10):
+     if profiling_output and i == 7 and not profiler_started:
+       jax.profiler.start_trace(profiling_output)
+       profiler_started = True
+     start = time.perf_counter()
+     # pylint: disable-next=all
+     decode_state, sampled_tokens = pt_engine.generate(params, decode_state)
+     jax.block_until_ready(decode_state)
+     sampled_tokens_list.append(sampled_tokens)
+     end = time.perf_counter()
+     dec_times.append(end - start)
+     print(i, "decode time", (end - start))
+ 
+   if profiler_started:
+     jax.profiler.stop_trace()
+ 
+   print("prefill ", prefill_times)
+   # Average decode step time, excluding the first two (warm-up) steps.
+   avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
+   print("decode", avg_decode_times)
+ 
+   prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
+   # 8 timed steps remain after dropping the two warm-up steps.
+   decode_time_ms = sum(dec_times[2:]) * 1000 / 8
+ 
+   largest_prefill = max(prefill_times.items())
+   print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
+ 
+   time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
+       FLAGS.batch_size * largest_prefill[1]
+       + FLAGS.max_decode_length * avg_decode_times
+   )
+   print("MAX tokens 2:", time2)
+ 
+   if FLAGS.benchmark_save_offline_result_to_file:
+     _save_benchmark_to_file(
+         FLAGS.benchmark_save_offline_result_to_file,
+         prefill_times_ms,
+         decode_time_ms,
+     )
+
+
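A rough reading of the two printed estimates: "MAX tokens" is the steady-state decode throughput (one token per slot per step, so batch_size divided by the average decode step time), while "MAX tokens 2" also charges the time to prefill every batch slot once at the largest measured prefill length. A small sketch with made-up numbers, for illustration only:

    # Hypothetical values; the real run uses the measured timings above.
    batch_size = 96
    max_decode_length = 1024
    largest_prefill_time = 0.5  # seconds, prefill at the longest length
    avg_decode_time = 0.020     # seconds per decode step for the whole batch

    # Steady-state decode throughput: one token per slot per step.
    max_tokens = batch_size / avg_decode_time  # 4800 tokens/s

    # Throughput including prefill: total generated tokens divided by the time
    # to prefill every slot once plus max_decode_length decode steps.
    max_tokens_2 = (batch_size * max_decode_length) / (
        batch_size * largest_prefill_time
        + max_decode_length * avg_decode_time
    )  # ~1435 tokens/s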
def main():
  """Main function."""
@@ -221,6 +352,8 @@ def main_real(argv):
    serve()
  elif argv[1] == "interactive":
    interactive()
+   elif argv[1] == "benchmark_offline":
+     benchmark_offline()
  else:
    print(
        "Invalid arguments. Please specify 'list', 'serve', 'interactive', or 'benchmark_offline'."