Skip to content

Commit dc0921e

Browse files
qihqiwang2yn84
authored andcommitted
Add mlperf benchmark scripts in-tree. (#148)
1 parent 5855b3d commit dc0921e

12 files changed

+1141
-2
lines changed

benchmarks/mixtral_offline.sh

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
CACHE_LENGTH=1024
2+
INPUT_SIZE=512
3+
OUTPUT_SIZE=1024
4+
BATCH_SIZE=512
5+
CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
6+
7+
pushd ..
8+
python -m benchmarks.run_offline \
9+
--model_name=mixtral \
10+
--batch_size=$BATCH_SIZE \
11+
--max_cache_length=$CACHE_LENGTH \
12+
--max_decode_length=$OUTPUT_SIZE \
13+
--context_length=$INPUT_SIZE \
14+
--checkpoint_path=$CHECKPOINT_PATH/model.safetensors \
15+
--tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \
16+
--quantize_weights=1 \
17+
--quantize_type=int8_per_channel \
18+
--quantize_kv_cache=1 \
19+
--profiling_output=/mnt/disks/hanq/mixtral-profiles
20+
popd

benchmarks/run_offline.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,20 @@ def main(argv):
127127
jax.profiler.stop_trace()
128128

129129
print("prefill ", prefill_times)
130-
print("decode", sum(dec_times) / 10)
130+
avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
131+
print("decode", avg_decode_times)
131132

132133
prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
133-
decode_time_ms = sum(dec_times) * 1000 / 10 / FLAGS.batch_size
134+
decode_time_ms = sum(dec_times[2:]) * 1000 / 8
135+
136+
largest_prefill = max(prefill_times.items())
137+
print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
138+
139+
time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
140+
FLAGS.batch_size * largest_prefill[1]
141+
+ FLAGS.max_decode_length * avg_decode_times
142+
)
143+
print("MAX tokens 2:", time2)
134144

135145
sharegpt_path = FLAGS.sharegpt_path
136146
if sharegpt_path:

mlperf/README.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Run MLPerf tests
2+
3+
NOTE: currently only tried with mixtral;
4+
and only tried with offline benchmark
5+
6+
# How to run
7+
8+
### 1. Install
9+
10+
```
11+
./install.sh
12+
```
13+
14+
### 2. Start server
15+
16+
```
17+
./start_server.sh
18+
```
19+
20+
### 3. Warm up the server
21+
22+
```
23+
python warmup.py
24+
```
25+
26+
### 4. Run the benchmark, now it runs offline mode
27+
28+
```
29+
./benchmark_run.sh
30+
```
31+

0 commit comments

Comments
 (0)