Commit ee040a4

wang2yn84, richardsliu, and FanhaiLu1 authored
Optimize cache update. (#151)
* Almost working except the mask; need to rebase to main to pick up the ring buffer support, then fix the mask. Int8 updates are also included but not tested.
* Fixed the test_model_impl for llama, but test_llama_e2e is still failing.
* Adds lazy_cache_update and restructures the cache flags.
* Disable all the prints. Fix create engine.
* Fix typos and minor errors.
* Fixes create engine.
* Adds new_cache_stacked and fixes cache update.
* Fix cache update when new_cache_stacked is False.
* Fix the cache manager and make unit tests pass except for one.
* Updates the exportable model to return the cache.
* Removed the fori loop in cache finalize. Moves cache.finalize() to the end of the existing cache attention.
* Try to use shard_map for cache update.
* Fix updating a single cache line in cache.finalize().
* Adds int8 support.
* Int8 left-aligned lazy cache update working; performance still not good enough.
* Fix the stacked cache introduced in the previous couple of commits.
* Put the original ragged attention back.
* Add the original ragged attention kernel.
* Fixes the bf16/int8 cache stack.
* Fix int8 stacked cache insertion in engine and finalization.
* Fixes int8 with lazy cache update.
* Updates the int8 test.
* Fix the int8 ragged attention output sharding.
* Fix group query attention broadcasting issue.
* Fix the shard_map input issue: variables not listed as inputs are frozen into the jit function (see the sketch after this message).
* Fix the flash attention mask shape; fix the quantized version of the single-cache-line update.
* Adds the kv cache test.
* Replace quantized cache "pos" with "input_pos" to align with the bf16 cache. Fix the kv cache quantization test.
* Fix the prefill cache insertion issue for the stacked cache; change the quantization reduce dims from (1, 3) to (-3, -1) to make it more robust.
* Adds lazy cache update with the generate cache stacked and the new cache unstacked, for performance validation.
* Fix the shard_map sharding for the stacked generate cache and the unstacked new cache.
* Use the JAX API for slicing instead of PyTorch index slicing.
* Adds stacked cache support in the ragged attention reference kernel.
* Adds stacked cache support for the modified ragged kernel.
* Llama2 70b int8 optimization done. Output not correct yet.
* Remove temporary test output files.
* Fix the llama 70b output accuracy issue resulting from GQA.
* Fixes the attention output slicing issue when not using flash attention. Refactor to use only one flash attention kernel. Changes the modified ring-buffer ragged attention kernel (quantization, layer, etc.).
* Fix the Pallas kernel out-of-bounds (OOB) issue.
* Fix tests; fix lint issues.
* Fix the interactive script.
* Fix lint errors.
* Fix errors.
* Fix the comments.
* Fix based on comments; fix all the unit tests.
* Fix the remaining pylint errors.
* Default the ring buffer back to true so that test_run_server and run_interactive can work in CPU mode. When we default the ring buffer to false, we should add flags to the run_interactive CI to set test mode to true so that the Pallas kernel can run.
* Fix all the lint errors.
* Fix run_offline script.
* Fix the long-latency issue in ring buffer mode.
* Rebase to main.
* Fix all the lint issues.
* Fix Ray engine crash on multihost (#164)
* Fix TPU head resource name for v4 and v5e (#165)
* Fix TPU head resource name for v4 and v5e
* fix format
* Fixed exhausted bug between head and workers (#163)
* add xla2 fix
* update jax version
* revert jax TPU version
* Fix the test_run_server issue introduced by the lint fixes; fix run_interactive after the merge; fix lints.
* Revert xla changes.
---------
Co-authored-by: Richard Liu <[email protected]>
Co-authored-by: Fanhai Lu <[email protected]>
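
One item above is worth illustrating: the shard_map input issue comes from JAX capturing closed-over values at trace time, so anything not passed as an explicit input is baked into the compiled function as a constant. A minimal sketch of that behavior, using plain jax.jit and made-up names (not code from this PR):

import jax
import jax.numpy as jnp

scale = jnp.float32(2.0)

@jax.jit
def closed_over(x):
  # `scale` is not an argument, so its value is captured when the function
  # is first traced and frozen into the compiled computation.
  return x * scale

@jax.jit
def passed_in(x, s):
  # `s` is an explicit input, so each call can supply a fresh value.
  return x * s

x = jnp.ones(4)
print(closed_over(x))        # [2. 2. 2. 2.]
scale = jnp.float32(3.0)     # rebinding the global is not seen by the jit cache
print(closed_over(x))        # still [2. 2. 2. 2.]
print(passed_in(x, scale))   # [3. 3. 3. 3.]

The same capture applies to values closed over by a shard_map-wrapped function running under jit, which is presumably why the fix lists the affected variables as explicit inputs rather than closing over them.
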
1 parent dcc7b27 commit ee040a4

24 files changed (+1743, -346 lines)

benchmarks/run_offline.py

Lines changed: 14 additions & 6 deletions
@@ -32,7 +32,7 @@
 flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file")
 
 
-def run_prefill_time(engine, params, decode_state, seqlen):
+def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
   """Run prefill and measure time."""
   metadata = engine.get_tokenizer()
   tokenizer = engine.build_tokenizer(metadata)
@@ -53,15 +53,20 @@ def run_prefill_time(engine, params, decode_state, seqlen):
   nums = 5
   start = time.perf_counter()
   for i in range(nums):
+    if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
+      jax.profiler.start_trace(FLAGS.profiling_output)
+      profiler_started = True
+
     prefill_result, _ = engine.prefill(
         params=params, padded_tokens=tokens, true_length=true_length
     )
     decode_state = engine.insert(
        prefill_result, decode_state, slot=jnp.int32(i)
     )
     jax.block_until_ready(decode_state)
+
   end = time.perf_counter()
-  return (end - start) / nums, decode_state
+  return (end - start) / nums, decode_state, profiler_started
 
 
 MAXTEXT_PREFILL = {
@@ -86,9 +91,10 @@ def main(argv):
   prefill_times = {}
 
   decode_state = engine.init_decode_state()
+  profiler_started = False
   for batch, _ in MAXTEXT_PREFILL.items():
-    runtime, decode_state = run_prefill_time(
-        engine, params, decode_state, batch
+    runtime, decode_state, profiler_started = run_prefill_time(
+        engine, params, decode_state, batch, profiler_started
     )
     prefill_times[batch] = runtime
 
@@ -103,10 +109,12 @@
 
   profiling_output = FLAGS.profiling_output
   print("======= decode starting ===")
+
   dec_times = []
   for i in range(10):
-    if profiling_output and i == 7:
+    if profiling_output and i == 7 and not profiler_started:
       jax.profiler.start_trace(profiling_output)
+      profiler_started = True
     start = time.perf_counter()
     # pylint: disable-next=all
     decode_state, sampled_tokens = engine.generate(params, decode_state)
@@ -116,7 +124,7 @@
     dec_times.append(end - start)
     print(i, "decode time", (end - start))
 
-  if profiling_output:
+  if profiler_started:
     jax.profiler.stop_trace()
 
   print("prefill ", prefill_times)
