Add offline perf ci #181

Merged (4 commits) on Sep 13, 2024

66 changes: 66 additions & 0 deletions .github/workflows/offline_perf.yaml
@@ -0,0 +1,66 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This workflow installs dependencies and runs the offline performance micro benchmark on pull requests
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Offline Performance

on:
pull_request:

jobs:
py:
name: "Offline micro benchmark"
strategy:
matrix:
python-version: ['3.10']
runs-on: self-hosted
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install Dependencies
run: |
python -m venv venv
source venv/bin/activate
source install_everything.sh
wget https://github.com/jqlang/jq/releases/download/jq-1.7.1/jq-linux-amd64
chmod +x ./jq-linux-amd64
env
python -c "import jax; print(jax.devices())"
- name: Run offlinebench
env:
JAX_PLATFORMS: tpu,cpu
HF_TOKEN : ${{ secrets.HF_TOKEN}}
run: |
set -euo pipefail
source venv/bin/activate
JAX_PLATFORMS=tpu,cpu python -m jetstream_pt.cli benchmark_offline --model_id meta-llama/Meta-Llama-3-8B-Instruct --quantize_weights=0 --override_batch_size=128 --benchmark_save_offline_result_to_file=result.md --internal_use_random_weights=True --hf_token=$HF_TOKEN
cat result.md | ./jq-linux-amd64 -Rsa . > output.txt
- name: Update result to PR
env:
URL: ${{ github.event.pull_request.comments_url }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
curl \
-X POST \
$URL \
-H "Content-Type: application/json" \
-H "Authorization: token $GITHUB_TOKEN" \
--data "{ \"body\": $(cat output.txt) }"
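
For readers unfamiliar with the last two steps: jq -Rsa . reads result.md and emits the whole file as a single JSON-escaped string, which curl then posts as the comment body to the pull request's comments URL. Below is a rough Python equivalent of those two steps; it is an illustrative sketch only (not part of the PR), and the helper name post_benchmark_comment is hypothetical. It assumes the same URL and GITHUB_TOKEN environment variables that the workflow exports.

# Illustrative sketch only (not part of the PR): a Python equivalent of the
# jq + curl steps above. Assumes the same URL and GITHUB_TOKEN environment
# variables that the workflow exports.
import json
import os
import urllib.request

def post_benchmark_comment(result_path="result.md"):
    # jq -Rsa . turns the whole file into one JSON string; json.dumps does
    # the same escaping here.
    with open(result_path, "r", encoding="utf-8") as f:
        body = f.read()
    payload = json.dumps({"body": body}).encode("utf-8")
    req = urllib.request.Request(
        os.environ["URL"],  # the pull request's comments_url
        data=payload,
        method="POST",
        headers={
            "Content-Type": "application/json",
            "Authorization": "token " + os.environ["GITHUB_TOKEN"],
        },
    )
    with urllib.request.urlopen(req) as resp:
        print("GitHub API response status:", resp.status)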

7 changes: 4 additions & 3 deletions benchmarks/run_offline.py
@@ -92,11 +92,12 @@ def main(argv):

decode_state = engine.init_decode_state()
profiler_started = False
for batch, _ in MAXTEXT_PREFILL.items():
for exp in range(4, 11):
seqlen = 2**exp
runtime, decode_state, profiler_started = run_prefill_time(
engine, params, decode_state, batch, profiler_started
engine, params, decode_state, seqlen, profiler_started
)
prefill_times[batch] = runtime
prefill_times[seqlen] = runtime

sampled_tokens_list = []

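The change above swaps the MAXTEXT_PREFILL buckets for a sweep over power-of-two prefill lengths, and prefill_times is now keyed by sequence length rather than batch. For reference, a tiny illustrative snippet (not part of the PR) showing the lengths the new loop covers:

# Illustrative only: the prefill sequence lengths the new loop sweeps.
seqlens = [2**exp for exp in range(4, 11)]
print(seqlens)  # [16, 32, 64, 128, 256, 512, 1024]
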
141 changes: 141 additions & 0 deletions jetstream_pt/cli.py
@@ -5,6 +5,7 @@
# import torch_xla2 first!
import torch_xla2 # pylint: disable
import jax
from jax import numpy as jnp
from absl import app, flags
from jetstream.engine import token_utils
from jetstream.core import server_lib
@@ -26,6 +27,11 @@
flags.DEFINE_integer("max_output_length", 1024, "The batch size")
flags.DEFINE_integer("port", 9000, "port to listen on")
flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
flags.DEFINE_string(
"benchmark_save_offline_result_to_file",
"",
"if set, then save the result to the given file name",
)


def shard_weights(env, weights, weight_shardings):
@@ -114,6 +120,45 @@ def _check_model_id():
sys.exit(1)


def _run_prefill_time(
pt_engine, params, decode_state, seqlen, profiler_started
):
"""Run prefill and measure time."""
metadata = pt_engine.get_tokenizer()
tokenizer = pt_engine.build_tokenizer(metadata)

text = "This is a beautiful day"
tokens, true_length = tokenizer.encode(
text, is_bos=True, prefill_lengths=[seqlen]
)

for _ in range(3):
prefill_result, _ = pt_engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length
)
decode_state = pt_engine.insert(
prefill_result, decode_state, slot=jnp.int32(1)
)

nums = 5
start = time.perf_counter()
for i in range(nums):
if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
jax.profiler.start_trace(FLAGS.profiling_output)
profiler_started = True

prefill_result, _ = pt_engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length
)
decode_state = pt_engine.insert(
prefill_result, decode_state, slot=jnp.int32(i)
)
jax.block_until_ready(decode_state)

end = time.perf_counter()
return (end - start) / nums, decode_state, profiler_started


def interactive():
"""Run interactive"""
_check_model_id()
@@ -207,6 +252,100 @@ def interactive():
print(tokenizer.decode(sampled_tokens_list))


def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
lines = (
[
" # Offline benchmark numbers",
" ## Model: " + FLAGS.model_id,
f" ## Batch size: {FLAGS.override_batch_size}",
f" ## Quantize: {FLAGS.quantize_weights}",
" | | time (ms) |",
" |-------|-----------|",
]
+ [f"| Prefill {x} | {y} |" for x, y in prefill_times_ms.items()]
+ [f"| Decode | {decode_time_ms} |"]
)
with open(filename, "w", encoding="utf-8") as f:
f.write("\n".join(lines))
f.flush()


def benchmark_offline():
"""function to run engine offline."""
_check_model_id()
devices = server_lib.get_devices()
print(f"devices: {devices}")
pt_engine = create_engine(devices)

start = time.perf_counter()
params = pt_engine.load_params()
print("Load params ", time.perf_counter() - start)

prefill_times = {}

decode_state = pt_engine.init_decode_state()
profiler_started = False
# 16 .. 1024
for exp in range(4, 11):
batch = 2**exp
runtime, decode_state, profiler_started = _run_prefill_time(
pt_engine, params, decode_state, batch, profiler_started
)
prefill_times[batch] = runtime

sampled_tokens_list = []

for i in range(3): # warm up
# pylint: disable-next=all
decode_state, sampled_tokens = pt_engine.generate(
params=params, decode_state=decode_state
)
sampled_tokens_list.append(sampled_tokens)

profiling_output = FLAGS.profiling_output
print("======= decode starting ===")

dec_times = []
for i in range(10):
if profiling_output and i == 7 and not profiler_started:
jax.profiler.start_trace(profiling_output)
profiler_started = True
start = time.perf_counter()
# pylint: disable-next=all
decode_state, sampled_tokens = pt_engine.generate(params, decode_state)
jax.block_until_ready(decode_state)
sampled_tokens_list.append(sampled_tokens)
end = time.perf_counter()
dec_times.append(end - start)
print(i, "decode time", (end - start))

if profiler_started:
jax.profiler.stop_trace()

print("prefill ", prefill_times)
avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
print("decode", avg_decode_times)

prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
decode_time_ms = sum(dec_times[2:]) * 1000 / 8

largest_prefill = max(prefill_times.items())
print("MAX tokens:", FLAGS.batch_size / avg_decode_times)

time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
FLAGS.batch_size * largest_prefill[1]
+ FLAGS.max_decode_length * avg_decode_times
)
print("MAX tokens 2:", time2)

if FLAGS.benchmark_save_offline_result_to_file:
_save_benchmark_to_file(
FLAGS.benchmark_save_offline_result_to_file,
prefill_times_ms,
decode_time_ms,
)


def main():
"""Main function."""

@@ -221,6 +360,8 @@ def main_real(argv):
serve()
elif argv[1] == "interactive":
interactive()
elif argv[1] == "benchmark_offline":
benchmark_offline()
else:
print(
"Invalid arguments. please specify 'list', 'serve', or 'interactive'."
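
benchmark_offline and _run_prefill_time above follow the same measurement pattern: a few warm-up calls so XLA compilation is excluded, jax.block_until_ready before reading the clock (JAX dispatches work asynchronously, so the Python call returns before the device finishes), and an average over the remaining iterations (the decode loop additionally drops its first two timed iterations). Below is a generic, simplified sketch of that pattern with a toy step function; it is not the PR's code, and time_step is a hypothetical helper.

# Simplified sketch of the timing pattern used above, with a toy step function.
# Not part of the PR; the real code times prefill/generate calls instead.
import time
import jax
import jax.numpy as jnp

def time_step(step, state, warmup=3, iters=5):
    for _ in range(warmup):
        state = step(state)  # first call triggers tracing/compilation
    jax.block_until_ready(state)
    start = time.perf_counter()
    for _ in range(iters):
        state = step(state)
    jax.block_until_ready(state)  # wait for all dispatched work to finish
    return (time.perf_counter() - start) / iters, state

avg_s, _ = time_step(lambda x: jnp.sin(x) @ x, jnp.ones((1024, 1024)))
print(f"avg step time: {avg_s * 1000:.3f} ms")

As a side note, the "MAX tokens 2" figure printed above approximates end-to-end output tokens per second: the batch's total decoded tokens divided by the time to prefill the whole batch at the largest measured prefill length plus the time to run max_decode_length decode steps.
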
25 changes: 21 additions & 4 deletions jetstream_pt/fetch_models.py
@@ -23,6 +23,11 @@
"Directory to store downloaded/converted weights",
)
flags.DEFINE_string("hf_token", "", "huggingface token")
flags.DEFINE_bool(
"internal_use_random_weights",
False,
"Use random weights instead of HF weights. Testing only.",
)

flags.DEFINE_integer(
"override_max_cache_length",
@@ -158,23 +163,35 @@ def _load_weights(directory):
return state_dict


def _make_random_model_weights(model):
result = {}
for key, val in model.state_dict().items():
new_weights = torch.rand(val.shape, dtype=val.dtype, device="cpu")
result[key] = new_weights
return result


def instantiate_model_from_repo_id(
repo_id,
env,
):
"""Create model instance by hf model id.+"""
model_dir = _hf_dir(repo_id)
if not os.path.exists(model_dir) or not os.listdir(model_dir):
if not FLAGS.internal_use_random_weights and (
not os.path.exists(model_dir) or not os.listdir(model_dir)
):
# no weights have been downloaded yet
_hf_download(repo_id, model_dir, FLAGS.hf_token)
model_info = model_id_to_class.get(repo_id)
assert model_info is not None

env.device = "meta"
model = model_info.model_class.from_hf_model_id(repo_id, env)
weights = _load_weights(model_dir)
weights = model.convert_hf_weights(weights)

if not FLAGS.internal_use_random_weights:
weights = _load_weights(model_dir)
weights = model.convert_hf_weights(weights)
else:
weights = _make_random_model_weights(model)
model.load_state_dict(weights, assign=True, strict=False)

return model
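
The internal_use_random_weights path above lets CI skip the gated Hugging Face download: the model is built on the meta device and its state dict is filled with torch.rand tensors before load_state_dict(assign=True, strict=False). Below is a self-contained toy illustration of the same idea; TinyModel and make_random_weights are hypothetical names, and the PR applies this to the real model classes instead.

# Toy illustration of the random-weights path used for CI (not part of the PR).
# Requires a recent PyTorch (load_state_dict's assign= argument was added in 2.1).
import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8, bias=False)

def make_random_weights(model):
    # Mirror of _make_random_model_weights: one random CPU tensor per entry.
    return {
        key: torch.rand(val.shape, dtype=val.dtype, device="cpu")
        for key, val in model.state_dict().items()
    }

with torch.device("meta"):  # parameters created without real storage
    model = TinyModel()
weights = make_random_weights(model)
# assign=True swaps the meta tensors for the random ones instead of copying in place;
# strict=False tolerates keys missing from the generated dict.
model.load_state_dict(weights, assign=True, strict=False)
print(model.proj.weight.device)  # cpu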