
Qualcomm AI Engine Direct - The performance issue about mutable buffer #6493


Conversation

shewu-quic
Collaborator

@shewu-quic shewu-quic commented Oct 25, 2024

Summary:

Currently, we are running llama 3.2 1B instruct with a sequence length of 512 and profiling it with llama_main. The generated token rate is approximately 19 tokens per second; however, if we consider only the QNN execution time, it is about 56 tokens per second.

  1. Cherry-pick Quantize I/O
    • Improves performance by about 5%.
  2. Cherry-pick MHA to SHA
    • Although it reduces the total matmul op cycles by about 50%, it adds overhead to some binary ops (in the rotary embedding). For smaller models especially, this burden outweighs the optimization benefit, so we continue to use MHA for llama 3.2 1B instruct and are investigating ways to eliminate the additional overhead.
  3. Pre-compute scale_factor
  4. Eliminate redundant transpose and reshape ops in feed_forward
  5. Change the data type of the input tokens from int64 to int32
    • Allows the embedding op to benefit from backend optimizations (see the sketch after this list).
  6. Delegate the copy op to try to eliminate its runtime overhead
    • This does not completely solve the problem, and it is currently the main reason for the poor performance.
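For item 5, a toy sketch of the idea (not the actual export_llama change; the module and shapes below are illustrative only): keeping the token input in int32 lets the embedding lookup stay in a dtype the HTP backend handles well.

import torch

class TinyEmbed(torch.nn.Module):
    """Toy model used only to illustrate exporting with int32 token inputs."""

    def __init__(self):
        super().__init__()
        self.tok_embeddings = torch.nn.Embedding(128256, 64)

    def forward(self, tokens):
        # `tokens` is int32 rather than the usual int64 index tensor.
        return self.tok_embeddings(tokens)

# Exporting with an int32 example input keeps the index dtype backend-friendly.
ep = torch.export.export(TinyEmbed(), (torch.zeros(1, 1, dtype=torch.int32),))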

Issue description

As far as I know, to support the mutable buffer feature, to_executorch seems to insert copy operations for mutable buffers. These copy operations cause noticeable overhead at runtime, according to my profiling results below.
However, I have observed that XNNPACK does not seem to have this issue. Is there any way to avoid these copy operations while ensuring the mutable buffer still functions correctly?

| Time proportion | llama 3.2 1B instruct, seq_len=128 | llama 3.2 1B instruct, seq_len=512 | llama 3.2 3B instruct, seq_len=128 |
| --- | --- | --- | --- |
| Run in QnnExecuTorchBackend (us) | 13138 (61%) | 17086 (33%) | ~33000 (50%) |
| Run copy op in CPU (us) | 64*130 = 8320 (39%) | 64*530 = 33920 (67%) | ~32000 (50%) |

Note that:

  1. Run on SM8650 with QNN version 2.26.
  2. The 64 copy operations are for the key and value caches across 32 layers in the llama 3.2 1B instruct model.
  3. Export command:
 python -m examples.models.llama.export_llama -t /path/to/Llama3.2-1B-Instruct/tokenizer.model -p /path/to/Llama3.2-1B-Instruct/params.json -c /path/to/Llama3.2-1B-Instruct/consolidated.00.pth --use_kv_cache --qnn --disable_dynamic_shape --pt2e_quantize qnn_16a4w --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128  --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --max_seq_length 128
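As a rough way to confirm how many copy ops were inserted for the mutable buffers (a sketch only; `prog` is assumed to be the ExportedProgram obtained after to_executorch, and the exact op name may differ by version):

def count_copy_ops(prog):
    # Count graph nodes whose target looks like the aten copy_ op that
    # to_executorch inserts to write mutated buffers back.
    return sum(
        1
        for node in prog.graph_module.graph.nodes
        if node.op == "call_function" and "copy_" in str(node.target)
    )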


@shewu-quic shewu-quic changed the title Qualcomm AI Engine Direct - Optimize the performance for llama 3.2 1B instruct Qualcomm AI Engine Direct - The performance issue about mutable buffer Oct 25, 2024


@kimishpatel
Contributor

I have a couple of questions/suggestions.

For each optimization that was applied, what is the benefit, both as a percentage and in tokens/sec? For example, the change "Change the data type of input token from int64 to int32" seems a bit odd from the model's API perspective.

Regarding the mutable buffer copy at the end: this can be avoided only under the assumption that the delegate manages and consumes the mutable buffer entirely. So if you manage your own KV cache and consume that mutable buffer, then you can avoid this copy.

See this PR for allowing delegates to consume buffers like the KV cache: #4830

@cccclai
Contributor

cccclai commented Oct 28, 2024

Is the copy op faster, or is the QNN_PROPERTY_TENSOR_SUPPORT_UPDATEABLE_STATIC_TENSORS option faster?

@shewu-quic
Collaborator Author

Is the copy op faster, or is the QNN_PROPERTY_TENSOR_SUPPORT_UPDATEABLE_STATIC_TENSORS option faster?

I think that QNN_PROPERTY_TENSOR_SUPPORT_UPDATEABLE_STATIC_TENSORS is used for LoRA.
Maybe I could delegate the copy op as an add-zero op, which should be faster.
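A toy sketch of the add-zero idea (illustrative only): expressing the write-back as a regular elementwise add gives the backend an op it can place on HTP, instead of leaving a copy on the CPU.

import torch

class CopyAsAddZero(torch.nn.Module):
    def forward(self, updated_cache):
        # Functionally a copy of `updated_cache`, but expressed as an
        # elementwise op that a backend can delegate.
        return updated_cache + torch.zeros_like(updated_cache)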

@shewu-quic
Collaborator Author

shewu-quic commented Oct 28, 2024

Thanks for your suggestion.

For each optimization that was applied, what is the benefit, both as a percentage and in tokens/sec? For example, the change "Change the data type of input token from int64 to int32" seems a bit odd from the model's API perspective.

Got it. I will add the percentage for each item.
BTW, I think "Change the data type of input token from int64 to int32" is the change that improves performance the most: from 29.x tok/sec to 48.x tok/sec.
Int32 is QNN HTP friendly and significantly speeds up the QNN embedding op.

Regarding the mutable buffer copy at the end: this can be avoided only under the assumption that the delegate manages and consumes the mutable buffer entirely. So if you manage your own KV cache and consume that mutable buffer, then you can avoid this copy.

Does this mean we need to update the KV cache in place?
If so, I don't think QNN HTP can support that.

@kimishpatel
Contributor

kimishpatel commented Oct 28, 2024 via email

@cccclai
Contributor

cccclai commented Oct 28, 2024

Yeah, since this op is inserted inside to_executorch, how about we consume this state-update part in the QNN backend and then run insert_write_back_for_buffers_pass inside the QNN backend, so that the copy op ends up inside the delegate?

@shewu-quic
Collaborator Author

Yeah, since this op is inserted inside to_executorch, how about we consume this state-update part in the QNN backend and then run insert_write_back_for_buffers_pass inside the QNN backend, so that the copy op ends up inside the delegate?

I don’t quite understand. Could you explain a bit more?

Because the copy ops have already been inserted inside convert_pt2e, if possible we would like to delegate these copy ops and not insert them again inside to_executorch.

@cccclai
Contributor

cccclai commented Oct 29, 2024

Yeah, since this op is inserted inside to_executorch, how about we consume this state-update part in the QNN backend and then run insert_write_back_for_buffers_pass inside the QNN backend, so that the copy op ends up inside the delegate?

I don’t quite understand. Could you explain a bit more?

Because the copy ops have already been inserted inside convert_pt2e, if possible we would like to delegate these copy ops and not insert them again inside to_executorch.

Oh wait, can you share a bit about what the graph looks like after convert_pt2e? I thought they're still in-place ops. Is it related to the previous unexpected graph (#4627), where the hack was inserted somewhere around pt2e?

I thought it was:

to_edge()
                graph -> index_put_ (in-place op)
to_backend()
                graph -> index_put_ (in-place op)

to_executorch()
                graph -> detect in-place op `index_put_`, insert copy op
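For reference, a quick way to see which variant actually appears in the graph (a rough sketch; `ep` is assumed to be the ExportedProgram at whatever stage you want to inspect):

def list_index_put_nodes(ep):
    # Print every index_put-style node so you can tell whether the in-place
    # variant (index_put_) or the functional one (index_put) is present.
    for node in ep.graph_module.graph.nodes:
        if node.op == "call_function" and "index_put" in str(node.target):
            print(node.name, "->", node.target)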

@shewu-quic
Collaborator Author

shewu-quic commented Oct 29, 2024

Oh wait, can you share a bit about what the graph looks like after convert_pt2e? I thought they're still in-place ops. Is it related to the previous unexpected graph (#4627), where the hack was inserted somewhere around pt2e?

May I know how I can check whether it is an in-place op or not?
I see that it always uses index_put.
But I noticed that to_executorch in EdgeProgramManager and to_executorch in ExirExportedProgram seem to behave differently. In export_llama.py, EdgeProgramManager is used. Its behavior is as follows:

to_backend
[graph screenshot]
graph_signature

ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_getattr_l__self___layers___0___attention_sdpa_kv_cache_past_k_caches'), target='layers.0.attention.SDPA.kv_cache.past_k_caches', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_getattr_l__self___layers___0___attention_sdpa_kv_cache_past_v_caches'), target='layers.0.attention.SDPA.kv_cache.past_v_caches', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='tokens'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_pos'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_1'), target='layers.0.attention.SDPA.kv_cache.past_k_caches'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem'), target='layers.0.attention.SDPA.kv_cache.past_v_caches'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_2'), target=None)])

to_executorch
[graph screenshot]
graph_signature

ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_getattr_l__self___layers___0___attention_sdpa_kv_cache_past_k_caches'), target='layers.0.attention.SDPA.kv_cache.past_k_caches', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_getattr_l__self___layers___0___attention_sdpa_kv_cache_past_v_caches'), target='layers.0.attention.SDPA.kv_cache.past_v_caches', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='tokens'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_pos'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='copy__default'), target='layers.0.attention.SDPA.kv_cache.past_k_caches'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='copy__default_1'), target='layers.0.attention.SDPA.kv_cache.past_v_caches'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_2'), target=None)])

However, in our unit test, ExirExportedProgram is used, and its behavior is as follows:

to_backend
[graph screenshot]
graph_signature

ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_k_cache'), target='k_cache', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_pos'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='k_val'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem'), target='k_cache'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])

to_executorch
[graph screenshot]
graph_signature

ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_k_cache'), target='k_cache', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_pos'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='k_val'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem'), target='k_cache'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])

@shewu-quic
Collaborator Author

Oops, in our unit test, it doesn't seem to work with mutable buffers....

@shewu-quic
Collaborator Author

Hi @cccclai,
In my understanding, QNN likely does not support in-place operators. Under this assumption, we want to avoid the copy.
Is there a way to make the input and output share the same data_ptr for the mutable buffer? That would seem to avoid the copy.

@kimishpatel
Contributor

Oops, in our unit test, it doesn't seem to work with mutable buffers....

That's what I expect, because if we don't copy the mutable buffer back, it is functionally incorrect.

@kimishpatel
Contributor

Under this assumption

Can you manage the KV cache buffer inside the QNN delegate of ET? Meaning not in the QNN SDK, but in the QNN code inside backends/qualcomm in ET. If you maintain the KV cache as a buffer in the QNN delegate runtime, then you can either do the same copy inside the delegate runtime, or have the functional index_put alias the same tensor so that the index_put implementation modifies the original buffer. But this has to be done on your end, and I am not sure how difficult a change it is.
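A tiny eager-mode illustration of the aliasing point above (not delegate code; it just shows that the in-place variant mutates the held buffer, so nothing needs to be copied back):

import torch

cache = torch.zeros(1, 4, 12, 64)   # buffer owned by a (hypothetical) delegate runtime
k_val = torch.ones(1, 1, 12, 64)
pos = torch.tensor([2])

out = torch.ops.aten.index_put_(cache, [None, pos], k_val)
# The in-place op returns the same storage it mutated, so the "output"
# already lives in the delegate-owned buffer.
assert out.data_ptr() == cache.data_ptr()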

@kimishpatel
Contributor

So @cccclai and I were discussing; see if this is feasible:

  1. Tag buffers to be consumed by the delegate. This identifies the buffers to be handled by the delegate.
  2. Add support for owning and updating mutable buffers inside the QNN delegate runtime. This is effectively the same as https://github.com/pytorch/executorch/blob/main/examples/qualcomm/oss_scripts/llama2/runner/runner.cpp#L176C1-L185C1.
  3. Now you have the capability of managing mutable buffers inside the delegate. The delegate is responsible for copying data back for the mutated state. This does not improve perf yet, because we have just moved the copy inside the delegate instead of having it in the graph.
  4. Now you can add a llama-specific optimization, where you check which entries of the K/V cache will be updated based on start_pos and seq_len, and only update those positions (see the sketch after this list). This is a bit hacky and very specific to this particular llama runner.
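A rough sketch of the item-4 optimization (illustrative shapes and names; the real update would live inside the delegate runtime): only the rows selected by start_pos and seq_len are written, instead of copying the whole cache back.

import torch

def update_k_cache(k_cache, k_new, start_pos):
    # k_cache: (heads, max_seq_len, head_dim), owned by the delegate runtime.
    # k_new:   (heads, seq_len, head_dim), freshly computed keys.
    seq_len = k_new.shape[1]
    # Write only the positions that actually changed.
    k_cache[:, start_pos : start_pos + seq_len, :] = k_new
    return k_cache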

@cccclai
Contributor

cccclai commented Oct 29, 2024

Regarding these parts,

  1. Now you have the capability of managing mutable buffers inside the delegate. The delegate is responsible for copying data back for the mutated state. This does not improve perf yet, because we have just moved the copy inside the delegate instead of having it in the graph.
  2. Now you can add a llama-specific optimization, where you check which entries of the K/V cache will be updated based on start_pos and seq_len, and only update those positions. This is a bit hacky and very specific to this particular llama runner.

I was just chatting with Min, and it looks like a DMA buffer can be accessed from both QNN and the CPU. Can we use a DMA buffer for the KV cache to reduce the copy, and perhaps as a way to update in place?

@kimishpatel
Contributor

Regarding these parts,

  1. Now you have the capability of managing mutable buffers inside the delegate. The delegate is responsible for copying data back for the mutated state. This does not improve perf yet, because we have just moved the copy inside the delegate instead of having it in the graph.
  2. Now you can add a llama-specific optimization, where you check which entries of the K/V cache will be updated based on start_pos and seq_len, and only update those positions. This is a bit hacky and very specific to this particular llama runner.

I was just chatting with Min, and it looks like a DMA buffer can be accessed from both QNN and the CPU. Can we use a DMA buffer for the KV cache to reduce the copy, and perhaps as a way to update in place?

This is for @shewu-quic

@shewu-quic
Collaborator Author

shewu-quic commented Oct 30, 2024

Thank you all for the discussion. I think I understand your points.

My initial idea yesterday was to see if we could use the memory planning pass during the AOT phase to ensure that the I/O of the mutable buffer uses the same address. This is because, for index_put, the I/O shapes are actually the same. For the QNN delegate at runtime, as long as the same address is given for the mutable buffer, it can be updated directly. I will write a simple test to check this.

If the above is not feasible, I think, as discussed, we need to:

  1. Allow the buffer to be consumed by the delegate and handled internally within the delegate.
  2. Allocate a buffer for the mutable buffer.

3 & 4. I think if the same memory address is used, there is no need for a copy-back operation. Of course, this memory address can be a shared buffer (DMA buf), which should also speed things up.

Would this be in line with your thoughts?

@shewu-quic
Collaborator Author

I tested a simple model and attempted to hack the mutable buffer to set the same memory address. It seems to work functionally. Next, I will try to apply this to LLaMA. I hope it succeeds.

Simple model:

import torch


class IndexPut(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Mutable buffer that plays the role of a KV cache.
        self.register_buffer(
            "k_cache",
            torch.zeros((1, 4, 12, 64), dtype=torch.float32),
        )

    def forward(self, input_pos, k_val):
        # In-place update of the cache at input_pos, followed by an add-zero
        # so the updated cache is also produced as a regular output.
        k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val)
        return k_out + torch.zeros((1, 4, 12, 64), dtype=torch.float32)

Sample inputs

sample_input = [
    (
        torch.tensor([2], dtype=torch.int32),
        torch.ones((1, 1, 12, 64), dtype=torch.float32),
    ),
    (
        torch.tensor([3], dtype=torch.int32),
        torch.ones((1, 1, 12, 64), dtype=torch.float32),
    ),
]

Results

# first inference
tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          ...,
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]]])
# second inference
tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          ...,
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627]],

         [[0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          ...,
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627],
          [0.0627, 0.0627, 0.0627,  ..., 0.0627, 0.0627, 0.0627]]]])
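(For a quick sanity check of the buffer mutation itself, independent of the quantized QNN run above, the module can also be run eagerly on the same sample inputs; the values will differ from the quantized output shown here.)

# Eager reference run: after the second call, both position 2 and position 3
# of k_cache hold the new values, confirming the buffer carries state across calls.
model = IndexPut()
for input_pos, k_val in sample_input:
    # Cast the index to int64 for the eager run; the exported flow above keeps it int32.
    print(model(input_pos.to(torch.int64), k_val))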

@cccclai
Contributor

cccclai commented Oct 30, 2024

Thank you for sharing your thoughts and providing the simple example. I’m wondering whether this small example uses the hack (using memory planning to point to the same address) or the list?

Memory planning is also officially supported; if you can share how you do this, we will have a better picture. We likely didn’t test memory planning with in-place memory updates, but theoretically it can work.

In the meanwhile, for the consume-mutable-buffer solution, these tensors are actually part of .mutable_buffers instead of .params. Hopefully that makes it easier to handle if we go for the second path.

@shewu-quic
Collaborator Author

shewu-quic commented Oct 30, 2024

Thank you for sharing your thoughts and providing the simple example. I’m wondering whether this small example uses the hack (using memory planning to point to the same address) or the list?

I hacked here to set the same address for the mutable buffer I/O.

Memory planning is also officially supported; if you can share how you do this, we will have a better picture. We likely didn’t test memory planning with in-place memory updates, but theoretically it can work.

Actually, I have no idea how to use the memory planning pass to set the same address for the mutable buffer I/O. Do you have any idea or an example of how to use it?

In the meanwhile, for the consume-mutable-buffer solution, these tensors are actually part of .mutable_buffers instead of .params. Hopefully that makes it easier to handle if we go for the second path.

Thanks for the information.
I think this is at AOT, right? I need to figure out how to distinguish mutable buffers at runtime. For now, I just prefix the tensor name with "mutbuf_" to identify it at runtime.
[screenshot]

@kimishpatel
Contributor

@shewu-quic

I would prefer that we take the second route you mentioned, where the QNN delegate consumes the mutable buffer. The issue with hacking around memory planning is that the mutable buffer is not memory planned in the conventional sense, and any update to the buffer via index_put cannot alias: the output cannot point to the input buffer, because the input buffer is really a graph input, as you have shown in the figure above.
Besides this, it violates the "no in-place mutation" invariant, which may have larger implications, and supporting it would need to be thought through more carefully. Hence it seems to me that the second approach is probably the best.

@shewu-quic
Collaborator Author

@shewu-quic

I would prefer that we take the second route you mentioned, where the QNN delegate consumes the mutable buffer. The issue with hacking around memory planning is that the mutable buffer is not memory planned in the conventional sense, and any update to the buffer via index_put cannot alias: the output cannot point to the input buffer, because the input buffer is really a graph input, as you have shown in the figure above. Besides this, it violates the "no in-place mutation" invariant, which may have larger implications, and supporting it would need to be thought through more carefully. Hence it seems to me that the second approach is probably the best.

Got it, thank you for your detailed explanation. I will try the second method first.

- Delegate the mutable buffer at AOT
- Manage the mutable buffer at runtime
@cccclai
Contributor

cccclai commented Nov 6, 2024

Hi, just wanted to communicate the potential impact of enabling batch prefill: it is possible that the KV cache will become explicit I/O. Can we also ensure we won't have the copy overhead when we enable batch prefill? Sorry, there are many moving pieces.

@shewu-quic
Collaborator Author

Hi, just wanted to communicate the potential impact of enabling batch prefill: it is possible that the KV cache will become explicit I/O. Can we also ensure we won't have the copy overhead when we enable batch prefill? Sorry, there are many moving pieces.

Do you mean the KV cache will be inputs and outputs of the model, as in static llama?
That would be great.

@cccclai
Contributor

cccclai commented Nov 6, 2024

Do you mean the KV cache will be inputs and outputs of the model, as in static llama?

Yeah, it will be similar, but for prefill, when we output the KV cache, and for decode, when we pass in the KV cache, we still need to ensure it stays on HTP and there is no copy overhead.

@shewu-quic
Collaborator Author

Got it. For decode, we still need to eliminate the copy overhead.
I think my current mechanism only affects the mutable buffer.
For prefill, it is not a mutable buffer, so it would not be affected.

But I have good news and bad news.
The good news is that we could improve performance by setting the same memory address for the I/O of the KV cache:

  • llama 3.2 1B with seq_len=512 on SM8650: 18.870145 -> 51.339524 tok/sec
  • llama 3.2 3B with seq_len=512 on SM8650: 17.337082 tok/sec

The bad news is that this is still worse than XNNPACK.
We have observed large inference-time variance in QNN graph execution, which we suspect is due to the large output size. However, we are still gathering evidence to confirm this.

Therefore, we are attempting to reduce the output size and have reverted to using static llama for verification. We hope to provide an update on the situation today.

@shewu-quic shewu-quic closed this Feb 7, 2025