Skip to content

Commit b7a2310

Browse files
committed
make lance's change work for mixtral
1 parent 28e2dfe commit b7a2310

File tree

4 files changed

+150
-5
lines changed

4 files changed

+150
-5
lines changed

benchmarks/mixtral_offline.sh

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#!/usr/bin/env bash
# Run the offline Mixtral benchmark for one (cache length, batch size) pair.
#
# Usage: mixtral_offline.sh <cache_length> <batch_size>
CACHE_LENGTH="$1"
BATCH_SIZE="$2"
INPUT_SIZE=1024
OUTPUT_SIZE=1024
CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/

# run_offline must be invoked from the repository root, one level up.
pushd ..
python -m benchmarks.run_offline \
  --lazy_cache_update=1 \
  --ring_buffer=0 \
  --model_name=mixtral \
  --batch_size="$BATCH_SIZE" \
  --max_cache_length="$CACHE_LENGTH" \
  --max_decode_length="$OUTPUT_SIZE" \
  --context_length="$INPUT_SIZE" \
  --checkpoint_path="$CHECKPOINT_PATH"/model.safetensors \
  --tokenizer_path="$CHECKPOINT_PATH"/tokenizer.model \
  --quantize_weights=1 \
  --quantize_type=int8_per_channel \
  --quantize_kv_cache=1 \
  --profiling_output=/mnt/disks/hanq/mixtral-profiles
popd
# Use the named variables (same values as $2/$1) so the summary line stays
# correct if the positional parameters are ever shifted.
echo "batch was $BATCH_SIZE cache was $CACHE_LENGTH"

benchmarks/offline_benchmark.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import math
2+
import pandas as pd
3+
import dataclasses
4+
from collections import defaultdict
5+
from absl import flags, app
6+
7+
from typing import Dict
8+
9+
FLAGS = flags.FLAGS

# Help text was empty; describe the expected input so --help is useful.
flags.DEFINE_string(
    'dataset_path', '',
    'Path to a pickled pandas DataFrame with tok_input_len and '
    'tok_ref_output_len columns.')
12+
13+
@dataclasses.dataclass
class Stat:
  """Measured performance numbers for one server configuration."""

  # Maximum KV-cache length the configuration was run with.
  cache_size: int
  # Number of sequences decoded together in one step.
  batch_size: int
  # Input-length bucket (power of two) -> measured prefill seconds for one
  # request of that bucket.
  prefill_times: Dict[int, float]
  # Measured seconds for one decode step over the full batch.
  decode_time: float
19+
20+
# Measured stats for scenario 1: one Stat per request bucket, ordered from
# the shortest-request bucket (largest batch) to the longest (smallest batch).
scenario1 = [
    Stat(
        cache_size=512,
        batch_size=2048,
        prefill_times={
            16: 0.016024088603444397,
            32: 0.021154335999926843,
            64: 0.02999803279999469,
            128: 0.043986773600045125,
            256: 0.07524209819985117,
            512: 0.13882793779994246,
        },
        decode_time=0.28033976474989686,
    ),
    Stat(
        cache_size=1280,
        batch_size=512,
        prefill_times={
            16: 0.016024088603444397,
            32: 0.020686019999993734,
            64: 0.02952769919993443,
            128: 0.04383329960000992,
            256: 0.07538782240008005,
            512: 0.13893127239989553,
            1024: 0.2693996697998955,
        },
        decode_time=0.11505070800001249,
    ),
    Stat(
        cache_size=3072,
        batch_size=256,
        prefill_times={
            32: 0.021193669800049976,
            64: 0.030565194799964956,
            128: 0.04334795760005363,
            256: 0.07586566419995507,
            512: 0.13899565000010625,
            1024: 0.26945373279995694,
            2048: 0.35605709000010394,
        },
        decode_time=0.06467210225014242,
    ),
]
46+
47+
# Measured stats for scenario 2.  All three request buckets were run with
# the same server configuration, so every entry carries the same numbers.
def _scenario2_entry():
  """Build one scenario-2 Stat (fresh prefill_times dict per entry)."""
  return Stat(
      cache_size=3072,
      batch_size=256,
      prefill_times={
          16: 0.018725800199899823,
          32: 0.02242145979980705,
          64: 0.02536285559981479,
          128: 0.034608948799723295,
          256: 0.0560826786000689,
          512: 0.10566568380017997,
          1024: 0.20719572800007882,
      },
      decode_time=0.0631,
  )


scenario2 = [_scenario2_entry() for _ in range(3)]
67+
def eval_scenario(dataset, scenario):
  """Estimate serving time of `scenario` on `dataset` and print a report.

  Args:
    dataset: pandas DataFrame with columns tok_input_len,
      tok_ref_output_len and bucket (an index into `scenario`).
    scenario: sequence of Stat, one per request bucket.
  """
  total_input_tokens = 0
  total_output_tokens = 0
  total_prefill_times = defaultdict(float)
  total_decode_times = defaultdict(float)
  output_tokens_by_bucket = defaultdict(int)
  for _, data in dataset.iterrows():
    stat = scenario[data.bucket]
    total_input_tokens += data.tok_input_len
    total_output_tokens += data.tok_ref_output_len
    # Prefill was measured per power-of-two bucket, so round the input
    # length up to the next power of two before looking it up.
    input_len_bucket = 2**math.ceil(math.log2(data.tok_input_len))
    total_prefill_times[input_len_bucket] += stat.prefill_times[input_len_bucket]
    output_tokens_by_bucket[data.bucket] += data.tok_ref_output_len

  # Decode processes a whole batch per step: (tokens / batch_size) steps
  # times the measured per-step latency.
  for k, tokens in output_tokens_by_bucket.items():
    stat = scenario[k]
    total_decode_times[k] = tokens / stat.batch_size * stat.decode_time

  prefill_total = sum(total_prefill_times.values())
  decode_total = sum(total_decode_times.values())
  print('Total input tokens', total_input_tokens)
  print('Total output tokens', total_output_tokens)
  print('Input / output', total_input_tokens / total_output_tokens)
  print('Prefill times', total_prefill_times)
  print('pref throughput', total_input_tokens / prefill_total)
  print('decode times', total_decode_times)
  print('decode throughput', total_output_tokens / decode_total)
  print('overall throughput',
        total_output_tokens /
        (decode_total + prefill_total))
  print('prefill total time', prefill_total)
  print('decode total time', decode_total)
102+
103+
104+
105+
def main(argv):
  """Load the pickled dataset, bucket every request, evaluate both scenarios."""
  dataset = pd.read_pickle(FLAGS.dataset_path)
  request_len = dataset.tok_input_len + dataset.tok_ref_output_len
  # Requests whose total length exceeds 1280 tokens, or whose input alone
  # exceeds 1024 tokens, need the largest-cache configuration.
  needs_large_cache = (request_len > 1280) | (dataset.tok_input_len > 1024)
  # Bucket 0: total <= 512; bucket 1: medium; bucket 2: large-cache requests.
  bucket = 0 + (request_len > 512) + needs_large_cache
  dataset.insert(2, 'bucket', bucket)
  eval_scenario(dataset, scenario1)
  print('======== scenario 2 ========')
  eval_scenario(dataset, scenario2)
113+
114+
# absl entry point: app.run parses the defined flags, then calls main().
if __name__ == '__main__':
  app.run(main)
116+
117+

benchmarks/run_offline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
7777
256: 23.59,
7878
512: 35.28,
7979
1024: 60.28,
80+
2048: 60.28,
8081
}
8182

8283

jetstream_pt/third_party/mixtral/model.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,15 @@ def forward(
7777
bsz, seqlen = idx.shape
7878
freqs_cis = self.freqs_cis[input_pos]
7979
freqs_cis = freqs_cis.reshape(bsz, seqlen, -1)
80-
assert len(caches) == len(
81-
self.layers
82-
), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match"
83-
for layer, cache in zip(self.layers, caches):
84-
with jax.named_scope("TransformerBlock"):
80+
81+
for layer_id, layer in enumerate(self.layers):
82+
if caches[0].stacked:
83+
cache = caches[0]
84+
else:
85+
cache = caches[layer_id]
86+
# else: for the stacked case, there is only 1 layer of kv cache
87+
88+
with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)):
8589
x = layer(
8690
x,
8791
freqs_cis,

0 commit comments

Comments
 (0)