Commit 60919fd

Add run_offline; also random weights
1 parent ac30781 commit 60919fd

File tree (4 files changed, +152 −6 lines changed):

.github/workflows/offline_perf.yaml
benchmarks/run_offline.py
jetstream_pt/cli.py
jetstream_pt/fetch_models.py

.github/workflows/offline_perf.yaml

Lines changed: 2 additions & 1 deletion
@@ -49,7 +49,8 @@ jobs:
       run: |
         set -euo pipefail
         source venv/bin/activate
-        python benchmarks/basic_ops.py | ./jq-linux-amd64 -Rsa . | tee output.txt
+        JAX_PLATFORMS=tpu,cpu python -m jetstream_pt.cli benchmark_offline --model_id meta-llama/Meta-Llama-3-8B-Instruct --quantize_weights=0 --override_batch_size=128 --benchmark_save_offline_result_to_file=result.md --internal_use_random_weights=True
+        cat result.md | ./jd-linux-amd64 > output.txt
     - name: Update result to PR
       env:
         URL: ${{ github.event.pull_request.comments_url }}
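
The CI step above can be reproduced outside the workflow. A minimal sketch (assuming jetstream_pt and its TPU/CPU JAX dependencies are installed locally; the flag values simply mirror the workflow step):

import os
import subprocess

# Run the offline benchmark the same way the CI step does and write the
# Markdown summary to result.md. JAX_PLATFORMS prefers TPU, falling back to CPU.
subprocess.run(
    [
        "python", "-m", "jetstream_pt.cli", "benchmark_offline",
        "--model_id=meta-llama/Meta-Llama-3-8B-Instruct",
        "--quantize_weights=0",
        "--override_batch_size=128",
        "--benchmark_save_offline_result_to_file=result.md",
        "--internal_use_random_weights=True",
    ],
    env={**os.environ, "JAX_PLATFORMS": "tpu,cpu"},
    check=True,
)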

benchmarks/run_offline.py

Lines changed: 2 additions & 1 deletion
@@ -92,7 +92,8 @@ def main(argv):
 
   decode_state = engine.init_decode_state()
   profiler_started = False
-  for batch, _ in MAXTEXT_PREFILL.items():
+  for exp in range(4, 11):
+    batch = 2 ** exp
     runtime, decode_state, profiler_started = run_prefill_time(
         engine, params, decode_state, batch, profiler_started
     )
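
The new loop replaces the MAXTEXT_PREFILL-driven sweep with an explicit power-of-two sweep. For reference, a one-line sketch of the prefill lengths it now measures:

# Prefill lengths swept by the new loop: 2**4 through 2**10.
prefill_lengths = [2**exp for exp in range(4, 11)]
assert prefill_lengths == [16, 32, 64, 128, 256, 512, 1024]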

jetstream_pt/cli.py

Lines changed: 133 additions & 0 deletions
@@ -5,6 +5,7 @@
 # import torch_xla2 first!
 import torch_xla2  # pylint: disable
 import jax
+from jax import numpy as jnp
 from absl import app, flags
 from jetstream.engine import token_utils
 from jetstream.core import server_lib
@@ -26,6 +27,7 @@
 flags.DEFINE_integer("max_output_length", 1024, "The batch size")
 flags.DEFINE_integer("port", 9000, "port to listen on")
 flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
+flags.DEFINE_string("benchmark_save_offline_result_to_file", "", "if set, then save the result to the given file name")
 
 
 def shard_weights(env, weights, weight_shardings):
@@ -113,6 +115,42 @@ def _check_model_id():
     list_model()
     sys.exit(1)
 
+def _run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
+  """Run prefill and measure time."""
+  metadata = engine.get_tokenizer()
+  tokenizer = engine.build_tokenizer(metadata)
+
+  text = "This is a beautiful day"
+  tokens, true_length = tokenizer.encode(
+      text, is_bos=True, prefill_lengths=[seqlen]
+  )
+
+  for _ in range(3):
+    prefill_result, _ = engine.prefill(
+        params=params, padded_tokens=tokens, true_length=true_length
+    )
+    decode_state = engine.insert(
+        prefill_result, decode_state, slot=jnp.int32(1)
+    )
+
+  nums = 5
+  start = time.perf_counter()
+  for i in range(nums):
+    if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
+      jax.profiler.start_trace(FLAGS.profiling_output)
+      profiler_started = True
+
+    prefill_result, _ = engine.prefill(
+        params=params, padded_tokens=tokens, true_length=true_length
+    )
+    decode_state = engine.insert(
+        prefill_result, decode_state, slot=jnp.int32(i)
+    )
+  jax.block_until_ready(decode_state)
+
+  end = time.perf_counter()
+  return (end - start) / nums, decode_state, profiler_started
+
 
 def interactive():
   """Run interactive"""
@@ -206,6 +244,99 @@ def interactive():
   print("---- All output text.")
   print(tokenizer.decode(sampled_tokens_list))
 
+def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
+  lines = [
+      " # Offline benchmark numbers",
+      " ## Command: " + " ".join(sys.argv),
+      " | | time (ms) |",
+      " |-------|-----------|",
+  ] + [
+      "| Prefill {} | {} |".format(x, y) for x, y in prefill_times_ms.items()
+  ] + [
+      "| Decode | {} |".format(decode_time_ms)
+  ]
+  with open(filename, 'w') as f:
+    f.write('\n'.join(lines))
+    f.flush()
+
+
+
+def benchmark_offline():
+  """function to run engine offline."""
+  _check_model_id()
+  devices = server_lib.get_devices()
+  print(f"devices: {devices}")
+  pt_engine = create_engine(devices)
+
+  start = time.perf_counter()
+  params = pt_engine.load_params()
+  print("Load params ", time.perf_counter() - start)
+
+  prefill_times = {}
+
+  decode_state = pt_engine.init_decode_state()
+  profiler_started = False
+  # 16 .. 1024
+  for exp in range(4, 11):
+    batch = 2 ** exp
+    runtime, decode_state, profiler_started = _run_prefill_time(
+        pt_engine, params, decode_state, batch, profiler_started
+    )
+    prefill_times[batch] = runtime
+
+  sampled_tokens_list = []
+
+  for i in range(3):  # warm up
+    # pylint: disable-next=all
+    decode_state, sampled_tokens = pt_engine.generate(
+        params=params, decode_state=decode_state
+    )
+    sampled_tokens_list.append(sampled_tokens)
+
+  profiling_output = FLAGS.profiling_output
+  print("======= decode starting ===")
+
+  dec_times = []
+  for i in range(10):
+    if profiling_output and i == 7 and not profiler_started:
+      jax.profiler.start_trace(profiling_output)
+      profiler_started = True
+    start = time.perf_counter()
+    # pylint: disable-next=all
+    decode_state, sampled_tokens = pt_engine.generate(params, decode_state)
+    jax.block_until_ready(decode_state)
+    sampled_tokens_list.append(sampled_tokens)
+    end = time.perf_counter()
+    dec_times.append(end - start)
+    print(i, "decode time", (end - start))
+
+  if profiler_started:
+    jax.profiler.stop_trace()
+
+  print("prefill ", prefill_times)
+  avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
+  print("decode", avg_decode_times)
+
+  prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
+  decode_time_ms = sum(dec_times[2:]) * 1000 / 8
+
+  largest_prefill = max(prefill_times.items())
+  print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
+
+  time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
+      FLAGS.batch_size * largest_prefill[1]
+      + FLAGS.max_decode_length * avg_decode_times
+  )
+  print("MAX tokens 2:", time2)
+
+  if FLAGS.benchmark_save_offline_result_to_file:
+    _save_benchmark_to_file(
+        FLAGS.benchmark_save_offline_result_to_file,
+        prefill_times_ms,
+        decode_time_ms
+    )
+
+
 
 def main():
   """Main function."""
@@ -221,6 +352,8 @@ def main_real(argv):
     serve()
   elif argv[1] == "interactive":
     interactive()
+  elif argv[1] == "benchmark_offline":
+    benchmark_offline()
   else:
     print(
         "Invalid arguments. please specify 'list', 'serve', or 'interactive'."

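The two "MAX tokens" figures printed by benchmark_offline are simple throughput estimates: the first assumes steady-state decode only (one token per slot per generate step), the second amortizes the cost of prefilling every slot once across max_decode_length decode steps. A worked sketch of the arithmetic with made-up timings (the batch size, step time, and prefill time below are assumptions, not measured values):

# Assumed inputs; in the real run these come from FLAGS and the measured loops.
batch_size = 128              # FLAGS.batch_size (128 in the CI invocation)
max_decode_length = 1024      # FLAGS.max_decode_length
avg_decode_time = 0.030       # mean seconds per generate() step over dec_times[2:]
largest_prefill_time = 0.250  # seconds for the longest (1024-token) prefill

# "MAX tokens": steady-state decode throughput, tokens per second.
max_tokens = batch_size / avg_decode_time

# "MAX tokens 2": throughput when every slot also pays for one prefill.
max_tokens_2 = (batch_size * max_decode_length) / (
    batch_size * largest_prefill_time
    + max_decode_length * avg_decode_time
)

print(max_tokens, max_tokens_2)  # ~4266.7 and ~2089.8 with these numbers
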
jetstream_pt/fetch_models.py

Lines changed: 15 additions & 4 deletions
@@ -23,6 +23,7 @@
     "Directory to store downloaded/converted weights",
 )
 flags.DEFINE_string("hf_token", "", "huggingface token")
+flags.DEFINE_bool("internal_use_random_weights", False, "Use random weights instead of HF weights. Testing only.")
 
 flags.DEFINE_integer(
     "override_max_cache_length",
@@ -157,24 +158,34 @@ def _load_weights(directory):
   # Load the state_dict into the model
   return state_dict
 
+def _make_random_model_weights(model):
+  result = {}
+  for key, val in model.state_dict().items():
+    new_weights = torch.rand(val.shape, dtype=val.dtype, device='cpu')
+    result[key] = new_weights
+  return result
+
 
 def instantiate_model_from_repo_id(
     repo_id,
     env,
 ):
   """Create model instance by hf model id.+"""
   model_dir = _hf_dir(repo_id)
-  if not os.path.exists(model_dir) or not os.listdir(model_dir):
+  if not FLAGS.internal_use_random_weights and (not os.path.exists(model_dir) or
+      not os.listdir(model_dir)):
     # no weights has been downloaded
     _hf_download(repo_id, model_dir, FLAGS.hf_token)
   model_info = model_id_to_class.get(repo_id)
   assert model_info is not None
 
   env.device = "meta"
   model = model_info.model_class.from_hf_model_id(repo_id, env)
-  weights = _load_weights(model_dir)
-  weights = model.convert_hf_weights(weights)
-
+  if not FLAGS.internal_use_random_weights:
+    weights = _load_weights(model_dir)
+    weights = model.convert_hf_weights(weights)
+  else:
+    weights = _make_random_model_weights(model)
   model.load_state_dict(weights, assign=True, strict=False)
 
   return model
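
The random-weights path skips the HuggingFace download entirely: the model is still instantiated from its repo id, but every tensor in its state_dict is replaced with a freshly sampled random tensor of the same shape and dtype. A minimal sketch of the same idea on a toy module (torch.nn.Linear is just a stand-in for the real model class):

import torch

# Build random weights matching the module's state_dict, mirroring what
# _make_random_model_weights does for the real model.
model = torch.nn.Linear(4, 2)
random_weights = {
    key: torch.rand(val.shape, dtype=val.dtype, device="cpu")
    for key, val in model.state_dict().items()
}
model.load_state_dict(random_weights, assign=True, strict=False)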
