Stacked cache mixtral. #155

Merged: 51 commits, Jul 20, 2024

Commits
b27f109
Almost working except mask, need to rebase to main to pick up the the…
wang2yn84 Jul 2, 2024
9fe9d08
Fixed the test_model_impl for llama, but test_llama_e2e is still fail…
wang2yn84 Jul 2, 2024
d7d6871
Adds lazy_cache_update and restructure the cache flags.
wang2yn84 Jul 3, 2024
3e84343
Disable all the prints. Fix create engine.
wang2yn84 Jul 3, 2024
6885b12
Fix typos and minor errors.
wang2yn84 Jul 3, 2024
c65744d
Fixes create engine.
wang2yn84 Jul 3, 2024
2094785
Adds new_cache_stacked and fixes cache update.
wang2yn84 Jul 4, 2024
2603e4f
Fix cache update when new_cache_stacked is False.
wang2yn84 Jul 4, 2024
8910fdf
Fix the cache manager and make unit tests pass except for 1.
wang2yn84 Jul 7, 2024
d903e54
Updates the exportable model to return cache.
wang2yn84 Jul 7, 2024
65c5197
Removed the fori loop in cache finalize. Moves the cache.finalize() t…
wang2yn84 Jul 8, 2024
9b8bf50
Try to use shard_map for cache update.
wang2yn84 Jul 8, 2024
ebb2fb8
Fix update single cache line in cache.finalize()
wang2yn84 Jul 8, 2024
88b348a
Adds int8 support.
wang2yn84 Jul 8, 2024
775afbe
Int8 left aligned lazy cache update working, performance still not go…
wang2yn84 Jul 9, 2024
151372e
Fix the stacked cache introduced in the previous couple of commits.
wang2yn84 Jul 9, 2024
a59a700
Put original ragged attention back.
wang2yn84 Jul 10, 2024
9575fd3
Add the original ragged attention kernel.
wang2yn84 Jul 10, 2024
cf79be5
Fixes the bf16/int8 cache stack.
wang2yn84 Jul 10, 2024
bbe8d90
Fix int8 stacked cache insertion in engine and finalization.
wang2yn84 Jul 10, 2024
e6b0cb9
Fixes int8 with lazy cache update.
wang2yn84 Jul 11, 2024
8329d50
Updates the int8 test.
wang2yn84 Jul 11, 2024
72c11c0
Fix the int8 ragged attention output sharding.
wang2yn84 Jul 11, 2024
15e1387
Fix group query attention broadcasting issue.
wang2yn84 Jul 11, 2024
c7b248a
Fix shard map input issue. Variables not listed as inputs are freezed…
wang2yn84 Jul 11, 2024
46791dc
Fix the flash attention mask shape; Fix the update single cache line …
wang2yn84 Jul 12, 2024
51a5f0a
Adds the kv cache test.
wang2yn84 Jul 12, 2024
0f0deab
Replace quantized cache "pos" with "input_pos" to align with bf16 cac…
wang2yn84 Jul 12, 2024
c5335b0
Fix prefill cache insertion issue for stacked cache; Changes reduce d…
wang2yn84 Jul 13, 2024
53bc76a
Adds lazy cache update with generate cache stacked new cache unstacke…
wang2yn84 Jul 15, 2024
a92191e
Fix the shard map sharding for stacked generate cache and unstacked n…
wang2yn84 Jul 15, 2024
6e1b35c
Using Jax API to slicing instead of Pytorch index slicing.
wang2yn84 Jul 15, 2024
c2c3103
Adds stacked cache support in ragged attention reference kernel.
wang2yn84 Jul 16, 2024
08b63aa
Adds stacked cache support for the modified ragged kernel.
wang2yn84 Jul 16, 2024
86e0c86
Llama2 70b int8 optimization done. Output not correct yet.
wang2yn84 Jul 16, 2024
3e32dcb
Remove testing temp output files.
wang2yn84 Jul 16, 2024
c52dd28
Fix the llama 70b output accuracy resulting from gqa.
wang2yn84 Jul 16, 2024
90655d3
Fixes the attention output slicing issue when not using flash attenti…
wang2yn84 Jul 17, 2024
b28d3c1
Fix the pallas kernel OOB issue
wang2yn84 Jul 18, 2024
2dffb49
Fix tests; Fix lint issues;
wang2yn84 Jul 18, 2024
5855b3d
Fix the interactive script.
wang2yn84 Jul 18, 2024
dc0921e
Add mlperf benchmark scripts in-tree. (#148)
qihqi Jul 15, 2024
41f59a1
Fix lint errors.
wang2yn84 Jul 19, 2024
d78a9bb
Fix errors.
wang2yn84 Jul 19, 2024
5035ce9
Fix the comments.
wang2yn84 Jul 19, 2024
e36833d
Fix based on comments; Fix all the unit tests.
wang2yn84 Jul 19, 2024
d263ff5
Fix the remaining pylint errors.
wang2yn84 Jul 19, 2024
fadc777
Default ring buffer back to true so that all the test_run_server and …
wang2yn84 Jul 19, 2024
703d71f
Fix all the lint errors.
wang2yn84 Jul 19, 2024
88f9ac8
Fix run_offline script.
wang2yn84 Jul 20, 2024
a5e47f8
Fix lint errors.
wang2yn84 Jul 20, 2024
20 changes: 20 additions & 0 deletions benchmarks/mixtral_offline.sh
@@ -0,0 +1,20 @@
+CACHE_LENGTH=1024
+INPUT_SIZE=512
+OUTPUT_SIZE=1024
+BATCH_SIZE=512
+CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
+
+pushd ..
+python -m benchmarks.run_offline \
+  --model_name=mixtral \
+  --batch_size=$BATCH_SIZE \
+  --max_cache_length=$CACHE_LENGTH \
+  --max_decode_length=$OUTPUT_SIZE \
+  --context_length=$INPUT_SIZE \
+  --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \
+  --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \
+  --quantize_weights=1 \
+  --quantize_type=int8_per_channel \
+  --quantize_kv_cache=1 \
+  --profiling_output=/mnt/disks/hanq/mixtral-profiles
+popd
34 changes: 26 additions & 8 deletions benchmarks/run_offline.py
@@ -32,7 +32,7 @@
 flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file")


-def run_prefill_time(engine, params, decode_state, seqlen):
+def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
   """Run prefill and measure time."""
   metadata = engine.get_tokenizer()
   tokenizer = engine.build_tokenizer(metadata)
@@ -53,15 +53,20 @@ def run_prefill_time(engine, params, decode_state, seqlen):
   nums = 5
   start = time.perf_counter()
   for i in range(nums):
+    if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
+      jax.profiler.start_trace(FLAGS.profiling_output)
+      profiler_started = True
+
     prefill_result, _ = engine.prefill(
         params=params, padded_tokens=tokens, true_length=true_length
     )
     decode_state = engine.insert(
         prefill_result, decode_state, slot=jnp.int32(i)
     )
+  jax.block_until_ready(decode_state)

   end = time.perf_counter()
-  return (end - start) / nums, decode_state
+  return (end - start) / nums, decode_state, profiler_started


 MAXTEXT_PREFILL = {
@@ -86,9 +91,10 @@ def main(argv):
   prefill_times = {}

   decode_state = engine.init_decode_state()
+  profiler_started = False
   for batch, _ in MAXTEXT_PREFILL.items():
-    runtime, decode_state = run_prefill_time(
-        engine, params, decode_state, batch
+    runtime, decode_state, profiler_started = run_prefill_time(
+        engine, params, decode_state, batch, profiler_started
     )
     prefill_times[batch] = runtime

@@ -103,10 +109,12 @@

   profiling_output = FLAGS.profiling_output
   print("======= decode starting ===")
+
   dec_times = []
   for i in range(10):
-    if profiling_output and i == 7:
+    if profiling_output and i == 7 and not profiler_started:
       jax.profiler.start_trace(profiling_output)
+      profiler_started = True
     start = time.perf_counter()
     # pylint: disable-next=all
     decode_state, sampled_tokens = engine.generate(params, decode_state)
@@ -116,14 +124,24 @@
     dec_times.append(end - start)
     print(i, "decode time", (end - start))

-  if profiling_output:
+  if profiler_started:
     jax.profiler.stop_trace()

   print("prefill ", prefill_times)
-  print("decode", sum(dec_times) / 10)
+  avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
+  print("decode", avg_decode_times)

   prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
-  decode_time_ms = sum(dec_times) * 1000 / 10 / FLAGS.batch_size
+  decode_time_ms = sum(dec_times[2:]) * 1000 / 8
+
+  largest_prefill = max(prefill_times.items())
+  print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
+
+  time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
+      FLAGS.batch_size * largest_prefill[1]
+      + FLAGS.max_decode_length * avg_decode_times
+  )
+  print("MAX tokens 2:", time2)

   sharegpt_path = FLAGS.sharegpt_path
   if sharegpt_path:
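Side note on the hunk above: the two "MAX tokens" figures are rough throughput estimates. Below is a minimal, self-contained sketch of the same arithmetic; the numeric inputs are illustrative assumptions, not values measured in this PR — only the formulas mirror the diff.

# Sketch of the throughput estimates printed by the new benchmark code.
# All concrete numbers here are assumed placeholders.
batch_size = 96              # stand-in for FLAGS.batch_size
max_decode_length = 1024     # stand-in for FLAGS.max_decode_length
largest_prefill_time = 0.25  # assumed prefill time (s) of the largest bucket
dec_times = [0.12, 0.10, 0.08, 0.08, 0.08, 0.08, 0.08, 0.08, 0.08, 0.08]  # assumed decode step times (s)

# Average decode step time, skipping the first two warm-up iterations,
# matching `sum(dec_times[2:]) / len(dec_times[2:])` in the diff.
avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])

# "MAX tokens": steady-state decode throughput. Each decode step emits one
# token per batch slot, so tokens/sec = batch size / average step time.
print("MAX tokens:", batch_size / avg_decode_times)

# "MAX tokens 2": throughput including prefill cost. Total tokens generated
# (batch_size * max_decode_length) divided by an estimated total time of
# one largest-bucket prefill per slot plus max_decode_length decode steps.
time2 = (batch_size * max_decode_length) / (
    batch_size * largest_prefill_time + max_decode_length * avg_decode_times
)
print("MAX tokens 2:", time2)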