diff --git a/.github/workflows/offline_perf.yaml b/.github/workflows/offline_perf.yaml
new file mode 100644
index 00000000..8230af11
--- /dev/null
+++ b/.github/workflows/offline_perf.yaml
@@ -0,0 +1,66 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow runs the offline performance microbenchmark on a self-hosted
+# runner and posts the result as a comment on the pull request.
+
+name: Offline Performance
+
+on:
+  pull_request:
+
+jobs:
+  py:
+    name: "Offline micro benchmark"
+    strategy:
+      matrix:
+        python-version: ['3.10']
+    runs-on: self-hosted
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install Dependencies
+        run: |
+          python -m venv venv
+          source venv/bin/activate
+          source install_everything.sh
+          wget https://github.com/jqlang/jq/releases/download/jq-1.7.1/jq-linux-amd64
+          chmod +x ./jq-linux-amd64
+          env
+          python -c "import jax; print(jax.devices())"
+      - name: Run offline benchmark
+        env:
+          JAX_PLATFORMS: tpu,cpu
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: |
+          set -euo pipefail
+          source venv/bin/activate
+          JAX_PLATFORMS=tpu,cpu python -m jetstream_pt.cli benchmark_offline --model_id meta-llama/Meta-Llama-3-8B-Instruct --quantize_weights=0 --override_batch_size=128 --benchmark_save_offline_result_to_file=result.md --internal_use_random_weights=True --hf_token=$HF_TOKEN
+          cat result.md | ./jq-linux-amd64 -Rsa . > output.txt
+      - name: Post result to PR
+        env:
+          URL: ${{ github.event.pull_request.comments_url }}
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        run: |
+          curl \
+            -X POST \
+            $URL \
+            -H "Content-Type: application/json" \
+            -H "Authorization: token $GITHUB_TOKEN" \
+            --data "{ \"body\": $(cat output.txt) }"
diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py
index 1fdc0cb7..7cc50639 100644
--- a/benchmarks/run_offline.py
+++ b/benchmarks/run_offline.py
@@ -92,11 +92,12 @@ def main(argv):
   prefill_times = {}

   decode_state = engine.init_decode_state()
   profiler_started = False
-  for batch, _ in MAXTEXT_PREFILL.items():
+  for exp in range(4, 11):
+    seqlen = 2**exp
     runtime, decode_state, profiler_started = run_prefill_time(
-        engine, params, decode_state, batch, profiler_started
+        engine, params, decode_state, seqlen, profiler_started
     )
-    prefill_times[batch] = runtime
+    prefill_times[seqlen] = runtime

   sampled_tokens_list = []
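The run_offline.py change above swaps the fixed MAXTEXT_PREFILL buckets for a power-of-two sweep of prefill lengths. For reference, these are the lengths the new loop exercises (plain Python, not part of the patch):

    >>> [2**exp for exp in range(4, 11)]
    [16, 32, 64, 128, 256, 512, 1024]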
diff --git a/jetstream_pt/cli.py b/jetstream_pt/cli.py
index ce49d552..2c0f85db 100644
--- a/jetstream_pt/cli.py
+++ b/jetstream_pt/cli.py
@@ -5,6 +5,7 @@
 # import torch_xla2 first!
 import torch_xla2  # pylint: disable
 import jax
+from jax import numpy as jnp
 from absl import app, flags
 from jetstream.engine import token_utils
 from jetstream.core import server_lib
@@ -26,6 +27,11 @@
 flags.DEFINE_integer("max_output_length", 1024, "The batch size")
 flags.DEFINE_integer("port", 9000, "port to listen on")
 flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
+flags.DEFINE_string(
+    "benchmark_save_offline_result_to_file",
+    "",
+    "If set, save the benchmark result to the given file name.",
+)

 def shard_weights(env, weights, weight_shardings):
@@ -114,6 +120,45 @@ def _check_model_id():
     sys.exit(1)

+def _run_prefill_time(
+    pt_engine, params, decode_state, seqlen, profiler_started
+):
+  """Run prefill and measure time."""
+  metadata = pt_engine.get_tokenizer()
+  tokenizer = pt_engine.build_tokenizer(metadata)
+
+  text = "This is a beautiful day"
+  tokens, true_length = tokenizer.encode(
+      text, is_bos=True, prefill_lengths=[seqlen]
+  )
+
+  for _ in range(3):
+    prefill_result, _ = pt_engine.prefill(
+        params=params, padded_tokens=tokens, true_length=true_length
+    )
+    decode_state = pt_engine.insert(
+        prefill_result, decode_state, slot=jnp.int32(1)
+    )
+
+  nums = 5
+  start = time.perf_counter()
+  for i in range(nums):
+    if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
+      jax.profiler.start_trace(FLAGS.profiling_output)
+      profiler_started = True
+
+    prefill_result, _ = pt_engine.prefill(
+        params=params, padded_tokens=tokens, true_length=true_length
+    )
+    decode_state = pt_engine.insert(
+        prefill_result, decode_state, slot=jnp.int32(i)
+    )
+  jax.block_until_ready(decode_state)
+
+  end = time.perf_counter()
+  return (end - start) / nums, decode_state, profiler_started
+
+
 def interactive():
   """Run interactive"""
   _check_model_id()
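The _run_prefill_time helper added above follows the standard JAX microbenchmarking recipe: a few untimed warmup calls so compilation cost is excluded, then an averaged timed loop with an explicit jax.block_until_ready so asynchronous dispatch cannot stop the clock early. The same pattern in isolation (a minimal sketch; f and x are placeholders, not names from this repo):

    import time
    import jax

    def time_fn(f, x, warmup=3, iters=5):
      # Untimed warmup iterations absorb tracing/compilation cost.
      for _ in range(warmup):
        jax.block_until_ready(f(x))
      start = time.perf_counter()
      for _ in range(iters):
        out = f(x)
      # Wait for the asynchronously dispatched work before stopping the clock.
      jax.block_until_ready(out)
      return (time.perf_counter() - start) / iters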
@@ -207,6 +252,100 @@ def interactive():
     print(tokenizer.decode(sampled_tokens_list))

+def _save_benchmark_to_file(filename, prefill_times_ms, decode_time_ms):
+  lines = (
+      [
+          " # Offline benchmark numbers",
+          " ## Model: " + FLAGS.model_id,
+          f" ## Batch size: {FLAGS.override_batch_size}",
+          f" ## Quantize: {FLAGS.quantize_weights}",
+          " | | time (ms) |",
+          " |-------|-----------|",
+      ]
+      + [f"| Prefill {x} | {y} |" for x, y in prefill_times_ms.items()]
+      + [f"| Decode | {decode_time_ms} |"]
+  )
+  with open(filename, "w", encoding="utf-8") as f:
+    f.write("\n".join(lines))
+    f.flush()
+
+
+def benchmark_offline():
+  """Run the engine offline and report benchmark numbers."""
+  _check_model_id()
+  devices = server_lib.get_devices()
+  print(f"devices: {devices}")
+  pt_engine = create_engine(devices)
+
+  start = time.perf_counter()
+  params = pt_engine.load_params()
+  print("Load params ", time.perf_counter() - start)
+
+  prefill_times = {}
+
+  decode_state = pt_engine.init_decode_state()
+  profiler_started = False
+  # 16 .. 1024
+  for exp in range(4, 11):
+    batch = 2**exp
+    runtime, decode_state, profiler_started = _run_prefill_time(
+        pt_engine, params, decode_state, batch, profiler_started
+    )
+    prefill_times[batch] = runtime
+
+  sampled_tokens_list = []
+
+  for i in range(3):  # warm up
+    # pylint: disable-next=all
+    decode_state, sampled_tokens = pt_engine.generate(
+        params=params, decode_state=decode_state
+    )
+    sampled_tokens_list.append(sampled_tokens)
+
+  profiling_output = FLAGS.profiling_output
+  print("======= decode starting ===")
+
+  dec_times = []
+  for i in range(10):
+    if profiling_output and i == 7 and not profiler_started:
+      jax.profiler.start_trace(profiling_output)
+      profiler_started = True
+    start = time.perf_counter()
+    # pylint: disable-next=all
+    decode_state, sampled_tokens = pt_engine.generate(params, decode_state)
+    jax.block_until_ready(decode_state)
+    sampled_tokens_list.append(sampled_tokens)
+    end = time.perf_counter()
+    dec_times.append(end - start)
+    print(i, "decode time", (end - start))
+
+  if profiler_started:
+    jax.profiler.stop_trace()
+
+  print("prefill ", prefill_times)
+  avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
+  print("decode", avg_decode_times)
+
+  prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
+  decode_time_ms = sum(dec_times[2:]) * 1000 / 8
+
+  largest_prefill = max(prefill_times.items())
+  print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
+
+  time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
+      FLAGS.batch_size * largest_prefill[1]
+      + FLAGS.max_decode_length * avg_decode_times
+  )
+  print("MAX tokens 2:", time2)
+
+  if FLAGS.benchmark_save_offline_result_to_file:
+    _save_benchmark_to_file(
+        FLAGS.benchmark_save_offline_result_to_file,
+        prefill_times_ms,
+        decode_time_ms,
+    )
+
+
 def main():
   """Main function."""

@@ -221,6 +360,8 @@ def main_real(argv):
     serve()
   elif argv[1] == "interactive":
     interactive()
+  elif argv[1] == "benchmark_offline":
+    benchmark_offline()
   else:
     print(
         "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
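The "MAX tokens 2" figure printed by benchmark_offline is a rough steady-state throughput estimate: the tokens one batch produces, divided by the time to prefill the whole batch at the largest measured length plus the time to run max_decode_length decode steps. The same arithmetic as a standalone helper (a sketch with hypothetical names, mirroring the time2 expression above):

    def est_tokens_per_sec(batch_size, decode_len, prefill_s, decode_step_s):
      # Tokens produced by one full batch...
      total_tokens = batch_size * decode_len
      # ...divided by one worst-case prefill per slot plus decode_len decode steps.
      total_time = batch_size * prefill_s + decode_len * decode_step_s
      return total_tokens / total_time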
diff --git a/jetstream_pt/fetch_models.py b/jetstream_pt/fetch_models.py
index c3e23125..65a0f7bd 100644
--- a/jetstream_pt/fetch_models.py
+++ b/jetstream_pt/fetch_models.py
@@ -23,6 +23,11 @@
     "Directory to store downloaded/converted weights",
 )
 flags.DEFINE_string("hf_token", "", "huggingface token")
+flags.DEFINE_bool(
+    "internal_use_random_weights",
+    False,
+    "Use random weights instead of HF weights. Testing only.",
+)

 flags.DEFINE_integer(
     "override_max_cache_length",
@@ -158,13 +163,23 @@ def _load_weights(directory):
   return state_dict

+def _make_random_model_weights(model):
+  result = {}
+  for key, val in model.state_dict().items():
+    new_weights = torch.rand(val.shape, dtype=val.dtype, device="cpu")
+    result[key] = new_weights
+  return result
+
+
 def instantiate_model_from_repo_id(
     repo_id,
     env,
 ):
   """Create model instance by hf model id."""
   model_dir = _hf_dir(repo_id)
-  if not os.path.exists(model_dir) or not os.listdir(model_dir):
+  if not FLAGS.internal_use_random_weights and (
+      not os.path.exists(model_dir) or not os.listdir(model_dir)
+  ):
     # no weights has been downloaded
     _hf_download(repo_id, model_dir, FLAGS.hf_token)
   model_info = model_id_to_class.get(repo_id)
@@ -172,9 +187,11 @@ def instantiate_model_from_repo_id(
   env.device = "meta"
   model = model_info.model_class.from_hf_model_id(repo_id, env)

-  weights = _load_weights(model_dir)
-  weights = model.convert_hf_weights(weights)
-
+  if not FLAGS.internal_use_random_weights:
+    weights = _load_weights(model_dir)
+    weights = model.convert_hf_weights(weights)
+  else:
+    weights = _make_random_model_weights(model)
   model.load_state_dict(weights, assign=True, strict=False)
   return model
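With internal_use_random_weights in place, the benchmark runs without downloading a checkpoint, which is what lets the CI job above run unattended. The numbers can be reproduced locally on a TPU VM with the same invocation the workflow uses:

    JAX_PLATFORMS=tpu,cpu python -m jetstream_pt.cli benchmark_offline \
        --model_id meta-llama/Meta-Llama-3-8B-Instruct \
        --quantize_weights=0 \
        --override_batch_size=128 \
        --benchmark_save_offline_result_to_file=result.md \
        --internal_use_random_weights=True \
        --hf_token=$HF_TOKEN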