Skip to content

Commit 9836998

Browse files
committed
lints
1 parent 674ea81 commit 9836998

File tree

4 files changed

+46
-34
lines changed

4 files changed

+46
-34
lines changed

.github/workflows/offline_perf.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
set -euo pipefail
5252
source venv/bin/activate
5353
JAX_PLATFORMS=tpu,cpu python -m jetstream_pt.cli benchmark_offline --model_id meta-llama/Meta-Llama-3-8B-Instruct --quantize_weights=0 --override_batch_size=128 --benchmark_save_offline_result_to_file=result.md --internal_use_random_weights=True --hf_token=$HF_TOKEN
54-
cat result.md | ./jq-linux-amd64 > output.txt
54+
cat result.md | ./jq-linux-amd64 -Rsa . > output.txt
5555
- name: Update result to PR
5656
env:
5757
URL: ${{ github.event.pull_request.comments_url }}

benchmarks/run_offline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def main(argv):
9393
decode_state = engine.init_decode_state()
9494
profiler_started = False
9595
for exp in range(4, 11):
96-
batch = 2 ** exp
96+
batch = 2**exp
9797
runtime, decode_state, profiler_started = run_prefill_time(
9898
engine, params, decode_state, batch, profiler_started
9999
)

jetstream_pt/cli.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
flags.DEFINE_integer("max_output_length", 1024, "The batch size")
2828
flags.DEFINE_integer("port", 9000, "port to listen on")
2929
flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
30-
flags.DEFINE_string("benchmark_save_offline_result_to_file", "", "if set, then save the result to the given file name")
30+
flags.DEFINE_string(
31+
"benchmark_save_offline_result_to_file",
32+
"",
33+
"if set, then save the result to the given file name",
34+
)
3135

3236

3337
def shard_weights(env, weights, weight_shardings):
@@ -115,21 +119,24 @@ def _check_model_id():
115119
list_model()
116120
sys.exit(1)
117121

118-
def _run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
122+
123+
def _run_prefill_time(
124+
pt_engine, params, decode_state, seqlen, profiler_started
125+
):
119126
"""Run prefill and measure time."""
120-
metadata = engine.get_tokenizer()
121-
tokenizer = engine.build_tokenizer(metadata)
127+
metadata = pt_engine.get_tokenizer()
128+
tokenizer = pt_engine.build_tokenizer(metadata)
122129

123130
text = "This is a beautiful day"
124131
tokens, true_length = tokenizer.encode(
125132
text, is_bos=True, prefill_lengths=[seqlen]
126133
)
127134

128135
for _ in range(3):
129-
prefill_result, _ = engine.prefill(
136+
prefill_result, _ = pt_engine.prefill(
130137
params=params, padded_tokens=tokens, true_length=true_length
131138
)
132-
decode_state = engine.insert(
139+
decode_state = pt_engine.insert(
133140
prefill_result, decode_state, slot=jnp.int32(1)
134141
)
135142

@@ -140,10 +147,10 @@ def _run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
140147
jax.profiler.start_trace(FLAGS.profiling_output)
141148
profiler_started = True
142149

143-
prefill_result, _ = engine.prefill(
150+
prefill_result, _ = pt_engine.prefill(
144151
params=params, padded_tokens=tokens, true_length=true_length
145152
)
146-
decode_state = engine.insert(
153+
decode_state = pt_engine.insert(
147154
prefill_result, decode_state, slot=jnp.int32(i)
148155
)
149156
jax.block_until_ready(decode_state)
@@ -244,25 +251,25 @@ def interactive():
244251
print("---- All output text.")
245252
print(tokenizer.decode(sampled_tokens_list))
246253

254+
247255
def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
248-
lines = [
249-
" # Offline benchmark numbers",
250-
" ## Model: " + FLAGS.model_id,
251-
" ## Batch size: {}".format(FLAGS.override_batch_size),
252-
" ## Quantize: {}".format(FLAGS.quantize_weights),
253-
" | | time (ms) |",
254-
" |-------|-----------|",
255-
] + [
256-
"| Prefill {} | {} |".format(x, y) for x, y in prefill_times_ms.items()
257-
] + [
258-
"| Decode | {} |".format(decode_time_ms)
259-
]
260-
with open(filename, 'w') as f:
261-
f.write('\n'.join(lines))
256+
lines = (
257+
[
258+
" # Offline benchmark numbers",
259+
" ## Model: " + FLAGS.model_id,
260+
f" ## Batch size: {FLAGS.override_batch_size}",
261+
f" ## Quantize: {FLAGS.quantize_weights}",
262+
" | | time (ms) |",
263+
" |-------|-----------|",
264+
]
265+
+ [f"| Prefill {x} | {y} |" for x, y in prefill_times_ms.items()]
266+
+ [f"| Decode | {decode_time_ms} |"]
267+
)
268+
with open(filename, "w", encoding="utf-8") as f:
269+
f.write("\n".join(lines))
262270
f.flush()
263271

264272

265-
266273
def benchmark_offline():
267274
"""function to run engine offline."""
268275
_check_model_id()
@@ -280,7 +287,7 @@ def benchmark_offline():
280287
profiler_started = False
281288
# 16 .. 1024
282289
for exp in range(4, 11):
283-
batch = 2 ** exp
290+
batch = 2**exp
284291
runtime, decode_state, profiler_started = _run_prefill_time(
285292
pt_engine, params, decode_state, batch, profiler_started
286293
)
@@ -333,13 +340,12 @@ def benchmark_offline():
333340

334341
if FLAGS.benchmark_save_offline_result_to_file:
335342
_save_benchmark_to_file(
336-
FLAGS.benchmark_save_offline_result_to_file,
337-
prefill_times_ms,
338-
decode_time_ms
343+
FLAGS.benchmark_save_offline_result_to_file,
344+
prefill_times_ms,
345+
decode_time_ms,
339346
)
340347

341348

342-
343349
def main():
344350
"""Main function."""
345351

jetstream_pt/fetch_models.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@
2323
"Directory to store downloaded/converted weights",
2424
)
2525
flags.DEFINE_string("hf_token", "", "huggingface token")
26-
flags.DEFINE_bool("internal_use_random_weights", False, "Use random weights instead of HF weights. Testing only.")
26+
flags.DEFINE_bool(
27+
"internal_use_random_weights",
28+
False,
29+
"Use random weights instead of HF weights. Testing only.",
30+
)
2731

2832
flags.DEFINE_integer(
2933
"override_max_cache_length",
@@ -158,10 +162,11 @@ def _load_weights(directory):
158162
# Load the state_dict into the model
159163
return state_dict
160164

165+
161166
def _make_random_model_weights(model):
162167
result = {}
163168
for key, val in model.state_dict().items():
164-
new_weights = torch.rand(val.shape, dtype=val.dtype, device='cpu')
169+
new_weights = torch.rand(val.shape, dtype=val.dtype, device="cpu")
165170
result[key] = new_weights
166171
return result
167172

@@ -172,8 +177,9 @@ def instantiate_model_from_repo_id(
172177
):
173178
"""Create model instance by hf model id."""
174179
model_dir = _hf_dir(repo_id)
175-
if not FLAGS.internal_use_random_weights and (not os.path.exists(model_dir) or
176-
not os.listdir(model_dir)):
180+
if not FLAGS.internal_use_random_weights and (
181+
not os.path.exists(model_dir) or not os.listdir(model_dir)
182+
):
177183
# no weights have been downloaded
178184
_hf_download(repo_id, model_dir, FLAGS.hf_token)
179185
model_info = model_id_to_class.get(repo_id)

0 commit comments

Comments
 (0)