
Commit 5b8823e

Add offline perf ci (#181)
* Add offline perf ci
* Add run_offline; also random weights
* lints
* rename variable
1 parent 33348d2 commit 5b8823e

4 files changed: +232 −7 lines

.github/workflows/offline_perf.yaml

Lines changed: 66 additions & 0 deletions
```diff
@@ -0,0 +1,66 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow installs Python dependencies and runs the offline performance micro benchmark.
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+name: Offline Performance
+
+on:
+  pull_request:
+
+jobs:
+  py:
+    name: "Offline micro benchmark"
+    strategy:
+      matrix:
+        python-version: ['3.10']
+    runs-on: self-hosted
+    steps:
+    - name: Checkout
+      uses: actions/checkout@v4
+    - name: Setup Python
+      uses: actions/setup-python@v4
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install Dependencies
+      run: |
+        python -m venv venv
+        source venv/bin/activate
+        source install_everything.sh
+        wget https://github.com/jqlang/jq/releases/download/jq-1.7.1/jq-linux-amd64
+        chmod +x ./jq-linux-amd64
+        env
+        python -c "import jax; print(jax.devices())"
+    - name: Run offline benchmark
+      env:
+        JAX_PLATFORMS: tpu,cpu
+        HF_TOKEN: ${{ secrets.HF_TOKEN }}
+      run: |
+        set -euo pipefail
+        source venv/bin/activate
+        JAX_PLATFORMS=tpu,cpu python -m jetstream_pt.cli benchmark_offline --model_id meta-llama/Meta-Llama-3-8B-Instruct --quantize_weights=0 --override_batch_size=128 --benchmark_save_offline_result_to_file=result.md --internal_use_random_weights=True --hf_token=$HF_TOKEN
+        cat result.md | ./jq-linux-amd64 -Rsa . > output.txt
+    - name: Update result to PR
+      env:
+        URL: ${{ github.event.pull_request.comments_url }}
+        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      run: |
+        curl \
+          -X POST \
+          $URL \
+          -H "Content-Type: application/json" \
+          -H "Authorization: token $GITHUB_TOKEN" \
+          --data "{ \"body\": $(cat output.txt) }"
```
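
The final two steps JSON-escape `result.md` (that is what `jq -Rsa .` does) and post it as a pull-request comment through the GitHub comments API. For reference, a minimal Python sketch of the same request (a hypothetical standalone script, assuming `URL` and `GITHUB_TOKEN` are exported exactly as in the step above):

```python
# Hypothetical Python equivalent of the "Update result to PR" step above.
import json
import os
import urllib.request

with open("result.md", encoding="utf-8") as f:
  body = f.read()

# json.dumps performs the escaping that `jq -Rsa .` does in the shell version.
req = urllib.request.Request(
    os.environ["URL"],  # github.event.pull_request.comments_url
    data=json.dumps({"body": body}).encode("utf-8"),
    headers={
        "Content-Type": "application/json",
        "Authorization": f"token {os.environ['GITHUB_TOKEN']}",
    },
    method="POST",
)
urllib.request.urlopen(req)
```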

benchmarks/run_offline.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -92,11 +92,12 @@ def main(argv):
 
   decode_state = engine.init_decode_state()
   profiler_started = False
-  for batch, _ in MAXTEXT_PREFILL.items():
+  for exp in range(4, 11):
+    seqlen = 2**exp
     runtime, decode_state, profiler_started = run_prefill_time(
-        engine, params, decode_state, batch, profiler_started
+        engine, params, decode_state, seqlen, profiler_started
     )
-    prefill_times[batch] = runtime
+    prefill_times[seqlen] = runtime
 
   sampled_tokens_list = []
```
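
The loop now sweeps fixed power-of-two prefill lengths instead of iterating over the `MAXTEXT_PREFILL` buckets. A quick check of the lengths it covers:

```python
print([2**exp for exp in range(4, 11)])  # [16, 32, 64, 128, 256, 512, 1024]
```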

jetstream_pt/cli.py

Lines changed: 141 additions & 0 deletions
```diff
@@ -5,6 +5,7 @@
 # import torch_xla2 first!
 import torch_xla2  # pylint: disable
 import jax
+from jax import numpy as jnp
 from absl import app, flags
 from jetstream.engine import token_utils
 from jetstream.core import server_lib
@@ -26,6 +27,11 @@
 flags.DEFINE_integer("max_output_length", 1024, "The batch size")
 flags.DEFINE_integer("port", 9000, "port to listen on")
 flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
+flags.DEFINE_string(
+    "benchmark_save_offline_result_to_file",
+    "",
+    "if set, then save the result to the given file name",
+)
 
 
 def shard_weights(env, weights, weight_shardings):
@@ -114,6 +120,45 @@ def _check_model_id():
     sys.exit(1)
 
 
+def _run_prefill_time(
+    pt_engine, params, decode_state, seqlen, profiler_started
+):
+  """Run prefill and measure time."""
+  metadata = pt_engine.get_tokenizer()
+  tokenizer = pt_engine.build_tokenizer(metadata)
+
+  text = "This is a beautiful day"
+  tokens, true_length = tokenizer.encode(
+      text, is_bos=True, prefill_lengths=[seqlen]
+  )
+
+  for _ in range(3):
+    prefill_result, _ = pt_engine.prefill(
+        params=params, padded_tokens=tokens, true_length=true_length
+    )
+    decode_state = pt_engine.insert(
+        prefill_result, decode_state, slot=jnp.int32(1)
+    )
+
+  nums = 5
+  start = time.perf_counter()
+  for i in range(nums):
+    if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
+      jax.profiler.start_trace(FLAGS.profiling_output)
+      profiler_started = True
+
+    prefill_result, _ = pt_engine.prefill(
+        params=params, padded_tokens=tokens, true_length=true_length
+    )
+    decode_state = pt_engine.insert(
+        prefill_result, decode_state, slot=jnp.int32(i)
+    )
+  jax.block_until_ready(decode_state)
+
+  end = time.perf_counter()
+  return (end - start) / nums, decode_state, profiler_started
+
+
 def interactive():
   """Run interactive"""
   _check_model_id()
```
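
`_run_prefill_time` uses the standard JAX timing pattern: a few untimed warmup calls absorb compilation, and `jax.block_until_ready` forces the asynchronously dispatched work to finish before the clock is read. A generic sketch of that pattern (the `timed_avg` helper is hypothetical, not part of this commit):

```python
import time

import jax


def timed_avg(fn, warmup=3, iters=5):
  """Average wall-clock seconds per call of fn(), a JAX computation."""
  for _ in range(warmup):
    jax.block_until_ready(fn())  # untimed: absorbs compilation and caching
  start = time.perf_counter()
  out = None
  for _ in range(iters):
    out = fn()
  # JAX dispatch is async; wait for the results before stopping the clock.
  jax.block_until_ready(out)
  return (time.perf_counter() - start) / iters
```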
```diff
@@ -207,6 +252,100 @@ def interactive():
     print(tokenizer.decode(sampled_tokens_list))
 
 
+def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
+  lines = (
+      [
+          " # Offline benchmark numbers",
+          " ## Model: " + FLAGS.model_id,
+          f" ## Batch size: {FLAGS.override_batch_size}",
+          f" ## Quantize: {FLAGS.quantize_weights}",
+          " | | time (ms) |",
+          " |-------|-----------|",
+      ]
+      + [f"| Prefill {x} | {y} |" for x, y in prefill_times_ms.items()]
+      + [f"| Decode | {decode_time_ms} |"]
+  )
+  with open(filename, "w", encoding="utf-8") as f:
+    f.write("\n".join(lines))
+    f.flush()
+
+
+def benchmark_offline():
+  """function to run engine offline."""
+  _check_model_id()
+  devices = server_lib.get_devices()
+  print(f"devices: {devices}")
+  pt_engine = create_engine(devices)
+
+  start = time.perf_counter()
+  params = pt_engine.load_params()
+  print("Load params ", time.perf_counter() - start)
+
+  prefill_times = {}
+
+  decode_state = pt_engine.init_decode_state()
+  profiler_started = False
+  # 16 .. 1024
+  for exp in range(4, 11):
+    batch = 2**exp
+    runtime, decode_state, profiler_started = _run_prefill_time(
+        pt_engine, params, decode_state, batch, profiler_started
+    )
+    prefill_times[batch] = runtime
+
+  sampled_tokens_list = []
+
+  for i in range(3):  # warm up
+    # pylint: disable-next=all
+    decode_state, sampled_tokens = pt_engine.generate(
+        params=params, decode_state=decode_state
+    )
+    sampled_tokens_list.append(sampled_tokens)
+
+  profiling_output = FLAGS.profiling_output
+  print("======= decode starting ===")
+
+  dec_times = []
+  for i in range(10):
+    if profiling_output and i == 7 and not profiler_started:
+      jax.profiler.start_trace(profiling_output)
+      profiler_started = True
+    start = time.perf_counter()
+    # pylint: disable-next=all
+    decode_state, sampled_tokens = pt_engine.generate(params, decode_state)
+    jax.block_until_ready(decode_state)
+    sampled_tokens_list.append(sampled_tokens)
+    end = time.perf_counter()
+    dec_times.append(end - start)
+    print(i, "decode time", (end - start))
+
+  if profiler_started:
+    jax.profiler.stop_trace()
+
+  print("prefill ", prefill_times)
+  avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
+  print("decode", avg_decode_times)
+
+  prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
+  decode_time_ms = sum(dec_times[2:]) * 1000 / 8
+
+  largest_prefill = max(prefill_times.items())
+  print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
+
+  time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
+      FLAGS.batch_size * largest_prefill[1]
+      + FLAGS.max_decode_length * avg_decode_times
+  )
+  print("MAX tokens 2:", time2)
+
+  if FLAGS.benchmark_save_offline_result_to_file:
+    _save_benchmark_to_file(
+        FLAGS.benchmark_save_offline_result_to_file,
+        prefill_times_ms,
+        decode_time_ms,
+    )
+
+
 def main():
   """Main function."""
 
```
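
`benchmark_offline` prints two throughput estimates: `MAX tokens` is the decode-only bound (each `generate` step emits one token per batch slot), while `MAX tokens 2` also charges every sequence one prefill at the largest length. Illustrative arithmetic with made-up numbers (not measured results):

```python
# Hypothetical inputs, purely to illustrate the two formulas above.
batch_size = 128          # FLAGS.batch_size
max_decode_length = 1024  # FLAGS.max_decode_length
largest_prefill_s = 0.25  # prefill time at the longest sequence length
avg_decode_s = 0.05       # one generate() step for the whole batch

# Decode-only bound: batch_size tokens per decode step.
max_tokens = batch_size / avg_decode_s  # 2560.0 tokens/s

# Estimate that also amortizes one (longest) prefill per sequence.
max_tokens_2 = (batch_size * max_decode_length) / (
    batch_size * largest_prefill_s + max_decode_length * avg_decode_s
)  # ~1575.4 tokens/s
```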

```diff
@@ -221,6 +360,8 @@ def main_real(argv):
     serve()
   elif argv[1] == "interactive":
     interactive()
+  elif argv[1] == "benchmark_offline":
+    benchmark_offline()
   else:
     print(
         "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
```

jetstream_pt/fetch_models.py

Lines changed: 21 additions & 4 deletions
```diff
@@ -23,6 +23,11 @@
     "Directory to store downloaded/converted weights",
 )
 flags.DEFINE_string("hf_token", "", "huggingface token")
+flags.DEFINE_bool(
+    "internal_use_random_weights",
+    False,
+    "Use random weights instead of HF weights. Testing only.",
+)
 
 flags.DEFINE_integer(
     "override_max_cache_length",
@@ -158,23 +163,35 @@ def _load_weights(directory):
   return state_dict
 
 
+def _make_random_model_weights(model):
+  result = {}
+  for key, val in model.state_dict().items():
+    new_weights = torch.rand(val.shape, dtype=val.dtype, device="cpu")
+    result[key] = new_weights
+  return result
+
+
 def instantiate_model_from_repo_id(
     repo_id,
     env,
 ):
   """Create model instance by hf model id."""
   model_dir = _hf_dir(repo_id)
-  if not os.path.exists(model_dir) or not os.listdir(model_dir):
+  if not FLAGS.internal_use_random_weights and (
+      not os.path.exists(model_dir) or not os.listdir(model_dir)
+  ):
     # no weights has been downloaded
     _hf_download(repo_id, model_dir, FLAGS.hf_token)
   model_info = model_id_to_class.get(repo_id)
   assert model_info is not None
 
   env.device = "meta"
   model = model_info.model_class.from_hf_model_id(repo_id, env)
-  weights = _load_weights(model_dir)
-  weights = model.convert_hf_weights(weights)
-
+  if not FLAGS.internal_use_random_weights:
+    weights = _load_weights(model_dir)
+    weights = model.convert_hf_weights(weights)
+  else:
+    weights = _make_random_model_weights(model)
   model.load_state_dict(weights, assign=True, strict=False)
 
   return model
```
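
Random weights work here because the model is first built on the `meta` device (parameters carry shape and dtype but no storage), and `load_state_dict(..., assign=True)` then swaps in the real tensors instead of copying into the meta ones. A toy sketch of the pattern (standalone example; assumes PyTorch 2.1+ for `assign=` and floating-point parameters, since `torch.rand` only supports float dtypes):

```python
import torch
from torch import nn

# Build a module on the meta device: parameters have shape/dtype, no storage.
with torch.device("meta"):
  m = nn.Linear(4, 4)

# Random CPU tensors matching every entry of the state dict.
weights = {
    k: torch.rand(v.shape, dtype=v.dtype, device="cpu")
    for k, v in m.state_dict().items()
}

# assign=True replaces the meta parameters with these CPU tensors
# (an in-place copy into meta tensors would fail).
m.load_state_dict(weights, assign=True, strict=False)
```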
