Skip to content

Commit 674ea81

Browse files
committed
Add run_offline; also random weights
1 parent ac30781 commit 674ea81

File tree

4 files changed

+155
-6
lines changed

4 files changed

+155
-6
lines changed

.github/workflows/offline_perf.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ jobs:
4646
- name: Run offlinebench
  env:
    JAX_PLATFORMS: tpu,cpu
    HF_TOKEN: ${{ secrets.HF_TOKEN }}
  run: |
    set -euo pipefail
    source venv/bin/activate
    # Random weights are used, so no checkpoint download is needed; the token
    # is still forwarded (quoted, so an empty or odd value cannot split args).
    JAX_PLATFORMS=tpu,cpu python -m jetstream_pt.cli benchmark_offline --model_id meta-llama/Meta-Llama-3-8B-Instruct --quantize_weights=0 --override_batch_size=128 --benchmark_save_offline_result_to_file=result.md --internal_use_random_weights=True --hf_token="$HF_TOKEN"
    # result.md is Markdown, not JSON: read it raw (-R), slurp it into one
    # string (-s) and emit it as an ASCII-safe JSON string (-a). Without these
    # flags jq tries to parse the Markdown as JSON and the step fails.
    ./jq-linux-amd64 -Rsa . result.md > output.txt
5355
- name: Update result to PR
5456
env:
5557
URL: ${{ github.event.pull_request.comments_url }}

benchmarks/run_offline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ def main(argv):
9292

9393
decode_state = engine.init_decode_state()
9494
profiler_started = False
95-
for batch, _ in MAXTEXT_PREFILL.items():
95+
for exp in range(4, 11):
96+
batch = 2 ** exp
9697
runtime, decode_state, profiler_started = run_prefill_time(
9798
engine, params, decode_state, batch, profiler_started
9899
)

jetstream_pt/cli.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# import torch_xla2 first!
66
import torch_xla2 # pylint: disable
77
import jax
8+
from jax import numpy as jnp
89
from absl import app, flags
910
from jetstream.engine import token_utils
1011
from jetstream.core import server_lib
@@ -26,6 +27,7 @@
2627
flags.DEFINE_integer("max_output_length", 1024, "The batch size")
2728
flags.DEFINE_integer("port", 9000, "port to listen on")
2829
flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
30+
flags.DEFINE_string("benchmark_save_offline_result_to_file", "", "if set, then save the result to the given file name")
2931

3032

3133
def shard_weights(env, weights, weight_shardings):
@@ -113,6 +115,42 @@ def _check_model_id():
113115
list_model()
114116
sys.exit(1)
115117

118+
def _run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
  """Measure the average wall-clock time of one prefill followed by an insert.

  A fixed short prompt is tokenized with `prefill_lengths=[seqlen]`, three
  untimed warm-up iterations are run, and then five timed iterations.

  Args:
    engine: object providing get_tokenizer, build_tokenizer, prefill and
      insert.
    params: model parameters forwarded to `engine.prefill`.
    decode_state: decode state that prefill results are inserted into.
    seqlen: the single prefill length handed to the tokenizer.
    profiler_started: True if a JAX profiler trace is already running.

  Returns:
    A tuple `(seconds_per_iteration, decode_state, profiler_started)` where
    `profiler_started` becomes True if this call started a trace.
  """
  metadata = engine.get_tokenizer()
  tokenizer = engine.build_tokenizer(metadata)

  text = "This is a beautiful day"
  tokens, true_length = tokenizer.encode(
      text, is_bos=True, prefill_lengths=[seqlen]
  )

  # Warm-up iterations (not timed).
  for _ in range(3):
    prefill_result, _ = engine.prefill(
        params=params, padded_tokens=tokens, true_length=true_length
    )
    decode_state = engine.insert(
        prefill_result, decode_state, slot=jnp.int32(1)
    )
  # JAX dispatches work asynchronously: wait for the warm-up to finish so that
  # none of it is still executing once the timer below has started.
  jax.block_until_ready(decode_state)

  nums = 5
  start = time.perf_counter()
  for i in range(nums):
    # Only the last timed iteration is traced, and only if nothing else has
    # already started the profiler.
    if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
      jax.profiler.start_trace(FLAGS.profiling_output)
      profiler_started = True

    prefill_result, _ = engine.prefill(
        params=params, padded_tokens=tokens, true_length=true_length
    )
    decode_state = engine.insert(
        prefill_result, decode_state, slot=jnp.int32(i)
    )
  # Wait for all dispatched work before reading the clock.
  jax.block_until_ready(decode_state)

  end = time.perf_counter()
  return (end - start) / nums, decode_state, profiler_started
153+
116154

117155
def interactive():
118156
"""Run interactive"""
@@ -206,6 +244,101 @@ def interactive():
206244
print("---- All output text.")
207245
print(tokenizer.decode(sampled_tokens_list))
208246

247+
def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
  """Write the benchmark numbers to `filename` as a small Markdown report.

  Args:
    filename: path of the file to (over)write.
    prefill_times_ms: mapping whose items become one "Prefill <key>" table
      row each, in iteration order.
    decode_time_ms: value written in the single "Decode" table row.
  """
  report = [
      " # Offline benchmark numbers",
      f" ## Model: {FLAGS.model_id}",
      f" ## Batch size: {FLAGS.override_batch_size}",
      f" ## Quantize: {FLAGS.quantize_weights}",
      " |       | time (ms) |",
      " |-------|-----------|",
  ]
  for label, elapsed_ms in prefill_times_ms.items():
    report.append(f"| Prefill {label} | {elapsed_ms} |")
  report.append(f"| Decode | {decode_time_ms} |")

  with open(filename, 'w') as out:
    out.write('\n'.join(report))
    out.flush()
263+
264+
265+
266+
def benchmark_offline():
  """Run the engine offline and report prefill and decode timings.

  Steps: validate the model id, create the engine and load its params, time
  prefill for sequence lengths 16..1024 (powers of two), warm up decode, time
  ten decode steps, print the results and, if the
  `benchmark_save_offline_result_to_file` flag is set, write them to that
  file.
  """
  _check_model_id()
  devices = server_lib.get_devices()
  print(f"devices: {devices}")
  pt_engine = create_engine(devices)

  start = time.perf_counter()
  params = pt_engine.load_params()
  print("Load params ", time.perf_counter() - start)

  prefill_times = {}

  decode_state = pt_engine.init_decode_state()
  profiler_started = False
  # Prefill sequence lengths 16 .. 1024.
  for exp in range(4, 11):
    seqlen = 2**exp
    runtime, decode_state, profiler_started = _run_prefill_time(
        pt_engine, params, decode_state, seqlen, profiler_started
    )
    prefill_times[seqlen] = runtime

  # Sampled tokens are not needed for timing, so they are discarded.
  for _ in range(3):  # warm up
    # pylint: disable-next=all
    decode_state, _ = pt_engine.generate(
        params=params, decode_state=decode_state
    )

  profiling_output = FLAGS.profiling_output
  print("======= decode starting ===")

  num_decode_steps = 10
  num_discarded_steps = 2  # leading steps left out of the average
  profiled_step = 7

  dec_times = []
  for i in range(num_decode_steps):
    if profiling_output and i == profiled_step and not profiler_started:
      jax.profiler.start_trace(profiling_output)
      profiler_started = True
    start = time.perf_counter()
    # pylint: disable-next=all
    decode_state, _ = pt_engine.generate(params, decode_state)
    # Wait for the step to finish before reading the clock.
    jax.block_until_ready(decode_state)
    end = time.perf_counter()
    dec_times.append(end - start)
    print(i, "decode time", (end - start))

  if profiler_started:
    jax.profiler.stop_trace()

  print("prefill ", prefill_times)
  measured = dec_times[num_discarded_steps:]
  avg_decode_times = sum(measured) / len(measured)
  print("decode", avg_decode_times)

  prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
  # Derived from the average above instead of dividing by a hard-coded count.
  decode_time_ms = avg_decode_times * 1000

  # Prefill time measured at the longest sequence length.
  largest_prefill_time = prefill_times[max(prefill_times)]

  # Same batch-size flag that the saved report prints.
  # NOTE(review): assumes override_batch_size is the batch size the engine
  # actually runs with - confirm against create_engine.
  batch_size = FLAGS.override_batch_size
  print("MAX tokens:", batch_size / avg_decode_times)

  time2 = (batch_size * FLAGS.max_decode_length) / (
      batch_size * largest_prefill_time
      + FLAGS.max_decode_length * avg_decode_times
  )
  print("MAX tokens 2:", time2)

  if FLAGS.benchmark_save_offline_result_to_file:
    _save_benchmark_to_file(
        FLAGS.benchmark_save_offline_result_to_file,
        prefill_times_ms,
        decode_time_ms,
    )
340+
341+
209342

210343
def main():
211344
"""Main function."""
@@ -221,6 +354,8 @@ def main_real(argv):
221354
serve()
222355
elif argv[1] == "interactive":
223356
interactive()
357+
elif argv[1] == "benchmark_offline":
358+
benchmark_offline()
224359
else:
225360
print(
226361
"Invalid arguments. please specify 'list', 'serve', or 'interactive'."

jetstream_pt/fetch_models.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"Directory to store downloaded/converted weights",
2424
)
2525
flags.DEFINE_string("hf_token", "", "huggingface token")
26+
flags.DEFINE_bool("internal_use_random_weights", False, "Use random weights instead of HF weights. Testing only.")
2627

2728
flags.DEFINE_integer(
2829
"override_max_cache_length",
@@ -157,24 +158,34 @@ def _load_weights(directory):
157158
# Load the state_dict into the model
158159
return state_dict
159160

161+
def _make_random_model_weights(model):
162+
result = {}
163+
for key, val in model.state_dict().items():
164+
new_weights = torch.rand(val.shape, dtype=val.dtype, device='cpu')
165+
result[key] = new_weights
166+
return result
167+
160168

161169
def instantiate_model_from_repo_id(
    repo_id,
    env,
):
  """Build the model registered for a HF model id and load weights into it.

  Unless the `internal_use_random_weights` flag is set, the weights are
  downloaded when the local model directory is missing or empty, then loaded
  and converted. With the flag set, no download happens and random weights
  are generated instead.

  Args:
    repo_id: Hugging Face model id; must be a key of `model_id_to_class`.
    env: environment passed to the model constructor; its `device` attribute
      is set to "meta" by this function.

  Returns:
    The instantiated model with weights assigned.
  """
  model_dir = _hf_dir(repo_id)
  use_random_weights = FLAGS.internal_use_random_weights

  if not use_random_weights:
    have_local_copy = os.path.exists(model_dir) and os.listdir(model_dir)
    if not have_local_copy:
      # Nothing has been downloaded yet.
      _hf_download(repo_id, model_dir, FLAGS.hf_token)

  model_info = model_id_to_class.get(repo_id)
  assert model_info is not None

  env.device = "meta"
  model = model_info.model_class.from_hf_model_id(repo_id, env)

  if use_random_weights:
    weights = _make_random_model_weights(model)
  else:
    weights = model.convert_hf_weights(_load_weights(model_dir))
  model.load_state_dict(weights, assign=True, strict=False)

  return model

0 commit comments

Comments
 (0)