Optimize cache update. #151


Merged: 57 commits (Aug 6, 2024)

Commits
f61641f
Almost working except mask, need to rebase to main to pick up the the…
wang2yn84 Jul 2, 2024
1bbceff
Fixed the test_model_impl for llama, but test_llama_e2e is still fail…
wang2yn84 Jul 2, 2024
eca6de7
Adds lazy_cache_update and restructure the cache flags.
wang2yn84 Jul 3, 2024
841f393
Disable all the prints. Fix create engine.
wang2yn84 Jul 3, 2024
90ffbbb
Fix typos and minor errors.
wang2yn84 Jul 3, 2024
dd4de9e
Fixes create engine.
wang2yn84 Jul 3, 2024
91181cd
Adds new_cache_stacked and fixes cache update.
wang2yn84 Jul 4, 2024
50a83d4
Fix cache update when new_cach_stacked is False.
wang2yn84 Jul 4, 2024
0336fb5
Fix the cache manager and make unit tests pass except for 1.
wang2yn84 Jul 7, 2024
ba3d385
Updates the exportable model to return cache.
wang2yn84 Jul 7, 2024
c808ca8
Removed the fori loop in cache finalize. Moves the cache.finalize() t…
wang2yn84 Jul 8, 2024
e2874fc
Try to use shard_map for cache update.
wang2yn84 Jul 8, 2024
0015e90
Fix update single cache line in cache.finalize()
wang2yn84 Jul 8, 2024
7661bb2
Adds int8 support.
wang2yn84 Jul 8, 2024
c965a42
Int8 left aligned lazy cache update working, performance still not go…
wang2yn84 Jul 9, 2024
8af0474
Fix the stacked cache introduced in the previous couple of commits.
wang2yn84 Jul 9, 2024
db0e3a3
Put original ragged attention back.
wang2yn84 Jul 10, 2024
7572a11
Add the original ragged attention kernel.
wang2yn84 Jul 10, 2024
efeb5de
Fixes the bf16/int8 cache stack.
wang2yn84 Jul 10, 2024
ed1acff
Fix int8 stacked cache insertion in engine and finalization.
wang2yn84 Jul 10, 2024
9460f66
Fixes int8 with lazy cache update.
wang2yn84 Jul 11, 2024
f188086
Updates the int8 test.
wang2yn84 Jul 11, 2024
124bb71
Fix the int8 ragged attention output sharding.
wang2yn84 Jul 11, 2024
ffaba5a
Fix group query attention broadcasting issue.
wang2yn84 Jul 11, 2024
78789f1
Fix shard map input issue. Variables not listed as inputs are freezed…
wang2yn84 Jul 11, 2024
e4643ee
Fix the flash attention mask shape; Fix the update single cache line …
wang2yn84 Jul 12, 2024
ef0b148
Adds the kv cache test.
wang2yn84 Jul 12, 2024
02e2d0b
Replace quantized cache "pos" with "input_pos" to align with bf16 cac…
wang2yn84 Jul 12, 2024
65b19a8
Fix prefill cache insertion issue for stacked cache; Changes reduce d…
wang2yn84 Jul 13, 2024
a87c608
Adds lazy cache update with generate cache stacked new cache unstacke…
wang2yn84 Jul 15, 2024
3170ef2
Fix the shard map sharding for stacked generate cache and unstacked n…
wang2yn84 Jul 15, 2024
ee1c011
Using Jax API to slicing instead of Pytorch index slicing.
wang2yn84 Jul 15, 2024
e08f31f
Adds stacked cache support in ragged attention reference kernel.
wang2yn84 Jul 16, 2024
b8e6b85
Adds stacked cache support for the modified ragged kernel.
wang2yn84 Jul 16, 2024
394e666
Llama2 70b int8 optimization done. Output not correct yet.
wang2yn84 Jul 16, 2024
0f24b8e
Remove testing temp output files.
wang2yn84 Jul 16, 2024
f905860
Fix the llama 70b output accuracy resulting from gqa.
wang2yn84 Jul 16, 2024
58dda18
Fixes the attention output slicing issue when not using flash attenti…
wang2yn84 Jul 17, 2024
ba80c19
Fix the pallas kernel OOB issue
wang2yn84 Jul 18, 2024
fa0ad3f
Fix tests; Fix lint issues;
wang2yn84 Jul 18, 2024
4b6bfcb
Fix the interactive script.
wang2yn84 Jul 18, 2024
57cd1ed
Fix lint errors.
wang2yn84 Jul 19, 2024
1f51536
Fix errors.
wang2yn84 Jul 19, 2024
3893e50
Fix the comments.
wang2yn84 Jul 19, 2024
89c4e88
Fix based on comments; Fix all the unit tests.
wang2yn84 Jul 19, 2024
004269b
Fix the remaining pylint errors.
wang2yn84 Jul 19, 2024
d0777fd
Default ring buffer back to true so that all the test_run_server and …
wang2yn84 Jul 19, 2024
e99a815
Fix all the lint errors.
wang2yn84 Jul 19, 2024
223338f
Fix run_offline script.
wang2yn84 Jul 20, 2024
1444e07
Fix the ring buffer mode long latency issue.
wang2yn84 Jul 30, 2024
595ead2
Rebase to main.
wang2yn84 Aug 5, 2024
d14e7f5
Fix all the lint issues.
wang2yn84 Aug 6, 2024
62f3c51
Fix Ray engine crash on multihost (#164)
richardsliu Aug 1, 2024
743c0e5
Fix TPU head resource name for v4 and v5e (#165)
richardsliu Aug 1, 2024
784801f
Fixed exhausted bug between head and workers (#163)
FanhaiLu1 Aug 2, 2024
d318ce4
Fix test_run_server issue from fixing the lint; Fix run_interactive f…
wang2yn84 Aug 6, 2024
8b26e9f
Revert xla changes.
wang2yn84 Aug 6, 2024
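Several of the commits above (for example ee1c011 "Using Jax API to slicing instead of Pytorch index slicing" and 0015e90 "Fix update single cache line in cache.finalize()") center on writing a single cache line with JAX primitives rather than PyTorch index assignment. The snippet below is only an illustrative sketch of that general technique; the function name, shapes, and cache layout are assumptions, not code from this PR.

```python
# Hypothetical sketch (not from this PR): write one sequence's new K/V entry
# into a decode cache at a given slot/position with jax.lax.dynamic_update_slice
# instead of PyTorch-style index assignment.
import jax
import jax.numpy as jnp


def update_cache_line(cache, new_entry, slot, pos):
  """cache: (batch, heads, seq_len, head_dim); new_entry: (heads, 1, head_dim)."""
  # Writes new_entry at batch index `slot`, sequence position `pos`.
  # Start indices may be traced values, so this also works under jit.
  return jax.lax.dynamic_update_slice(
      cache, new_entry[None, ...], (slot, 0, pos, 0)
  )


cache = jnp.zeros((8, 4, 1024, 128), dtype=jnp.bfloat16)
new_entry = jnp.ones((4, 1, 128), dtype=jnp.bfloat16)
cache = jax.jit(update_cache_line)(cache, new_entry, jnp.int32(2), jnp.int32(17))
print(cache[2, :, 17, 0])  # the updated line now holds the new entry
```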
README.md (1 change: 1 addition, 0 deletions)

@@ -184,6 +184,7 @@ Note: Get address ip and port information from ray head.
Here is an example to run the server with ray for llama2 7B model:

```bash
+export DISABLE_XLA2_PJRT_TEST="true"
python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips=4 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
```

benchmarks/run_offline.py (20 changes: 14 additions, 6 deletions)

@@ -32,7 +32,7 @@
flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file")


-def run_prefill_time(engine, params, decode_state, seqlen):
+def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
"""Run prefill and measure time."""
metadata = engine.get_tokenizer()
tokenizer = engine.build_tokenizer(metadata)
@@ -53,15 +53,20 @@ def run_prefill_time(engine, params, decode_state, seqlen):
nums = 5
start = time.perf_counter()
for i in range(nums):
+if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
+  jax.profiler.start_trace(FLAGS.profiling_output)
+  profiler_started = True

prefill_result, _ = engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length
)
decode_state = engine.insert(
prefill_result, decode_state, slot=jnp.int32(i)
)
jax.block_until_ready(decode_state)

end = time.perf_counter()
-return (end - start) / nums, decode_state
+return (end - start) / nums, decode_state, profiler_started


MAXTEXT_PREFILL = {
@@ -86,9 +91,10 @@ def main(argv):
prefill_times = {}

decode_state = engine.init_decode_state()
+profiler_started = False
for batch, _ in MAXTEXT_PREFILL.items():
-runtime, decode_state = run_prefill_time(
-    engine, params, decode_state, batch
+runtime, decode_state, profiler_started = run_prefill_time(
+    engine, params, decode_state, batch, profiler_started
)
prefill_times[batch] = runtime

@@ -103,10 +109,12 @@

profiling_output = FLAGS.profiling_output
print("======= decode starting ===")

dec_times = []
for i in range(10):
-if profiling_output and i == 7:
+if profiling_output and i == 7 and not profiler_started:
jax.profiler.start_trace(profiling_output)
+profiler_started = True
start = time.perf_counter()
# pylint: disable-next=all
decode_state, sampled_tokens = engine.generate(params, decode_state)
@@ -116,7 +124,7 @@ def main(argv):
dec_times.append(end - start)
print(i, "decode time", (end - start))

-if profiling_output:
+if profiler_started:
jax.profiler.stop_trace()

print("prefill ", prefill_times)
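For readers skimming the run_offline.py change: the new profiler_started flag ensures jax.profiler.start_trace is called at most once, either during the last profiled prefill or at decode step 7, and that stop_trace runs only if a trace was actually started. Below is a standalone sketch of the same pattern with a placeholder workload in place of the real engine calls; the output directory and the tanh/matmul step are assumptions for illustration only.

```python
# Assumed standalone example of the single-start profiler flag; the real
# benchmark drives engine.prefill/engine.generate instead of this workload.
import time

import jax
import jax.numpy as jnp

profiling_output = "/tmp/jax_trace"  # assumed trace output directory
profiler_started = False

x = jnp.ones((1024, 1024))
for i in range(10):
  # Start the trace exactly once, a few steps in, so compilation is excluded.
  if profiling_output and i == 7 and not profiler_started:
    jax.profiler.start_trace(profiling_output)
    profiler_started = True
  start = time.perf_counter()
  x = jnp.tanh(x @ x).block_until_ready()  # stand-in for engine.generate(...)
  print(i, "step time", time.perf_counter() - start)

# Stop only if a trace was started, so the loop also runs with profiling off.
if profiler_started:
  jax.profiler.stop_trace()
```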