
add 16a4w_hqq quant mode #3752


Open: cccclai wants to merge 5 commits into gh/cccclai/9/base
Conversation

@cccclai (Contributor) commented May 28, 2024

Stack from ghstack (oldest at bottom):

Prerequisite: install hqq following https://github.com/mobiusml/hqq

Step 1: use hqq to quantize weights to 4-bit
Step 2: use static quantization to quantize activations to 16-bit

Currently the graph calibration is too slow, so the quant observers are added to the eager model for faster iteration.

Command:
```
python -m examples.models.llama2.eval_llama  -t /data/users/chenlai/models/llama2/tokenizer.model -p /data/users/chenlai/models/llama2/params.json -c /data/users/chenlai/models/llama2/consolidated.00.pth  --max_seq_len 129   -qmode 16a4w-hqq  --limit 5  2>&1 | tee hqq_16a4w.log
```

Differential Revision: D57849772
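For readers who want to see the shape of the flow, below is a minimal, hand-rolled sketch of the 16a4w-hqq idea described above: HQQ quantizes each linear layer's weights to 4-bit, and a toy min/max observer (written out here purely for illustration; it is not the observer this PR adds) tracks activation ranges in eager mode for 16-bit static quantization. The `BaseQuantizeConfig`/`HQQLinear` names come from hqq's public API; treat the exact signatures, defaults, and the `group_size` choice as assumptions.

```python
# Illustrative sketch only -- this is NOT the code added by the PR.
# Assumes hqq's documented BaseQuantizeConfig/HQQLinear API (hqq ~0.1.7);
# device/compute_dtype kwargs are omitted and left at their defaults.
import torch
import torch.nn as nn
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear


class Int16ActObserver(nn.Module):
    """Toy eager-mode observer: records activation min/max for 16-bit static quant."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner
        self.register_buffer("act_min", torch.tensor(float("inf")))
        self.register_buffer("act_max", torch.tensor(float("-inf")))

    def forward(self, x):
        # Track the observed activation range during calibration.
        self.act_min = torch.minimum(self.act_min, x.detach().min())
        self.act_max = torch.maximum(self.act_max, x.detach().max())
        return self.inner(x)

    def qparams(self, qmin=-32768, qmax=32767):
        # Derive affine quantization parameters for a 16-bit integer range.
        scale = (self.act_max - self.act_min) / (qmax - qmin)
        zero_point = qmin - torch.round(self.act_min / scale)
        return scale, zero_point


def quantize_linear_16a4w(linear: nn.Linear) -> nn.Module:
    # Step 1: HQQ quantizes the weights to 4-bit.
    # group_size is a free knob here; the discussion below covers group_size=None.
    cfg = BaseQuantizeConfig(nbits=4, group_size=64)
    hqq_linear = HQQLinear(linear, quant_config=cfg)
    # Step 2: wrap with an observer so calibration data sets the 16-bit activation range.
    return Int16ActObserver(hqq_linear)
```

In this sketch you would push a handful of calibration batches through the wrapped model and then read scale/zero-point from each observer; the PR itself attaches observers to the eager model as described above.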


pytorch-bot bot commented May 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/3752

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 9d92858 with merge base c665c17:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label May 28, 2024
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D57849772

cccclai added a commit that referenced this pull request May 28, 2024
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D57849772

cccclai added a commit that referenced this pull request May 28, 2024
@cccclai (Contributor, Author) commented May 28, 2024

Needs to install the latest hqq and torchao-nightly with

pip install git+https://github.com/mobiusml/hqq.git
pip install torchao-nightly

Version reference:

hqq                       0.1.7.post2
torchao-nightly           2024.5.19

Also, I'll keep updating this PR; this commit, 5900db7, is the tested working one.
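If it helps, a quick way to confirm which versions actually got installed (standard library only; this assumes the pip distribution names are exactly `hqq` and `torchao-nightly`, as in the commands above):

```python
# Print the installed versions of the two prerequisites.
from importlib.metadata import version

for dist in ("hqq", "torchao-nightly"):
    print(dist, version(dist))
```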

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D57849772

cccclai added a commit that referenced this pull request May 28, 2024
cccclai marked this pull request as draft May 28, 2024 20:04
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D57849772

cccclai added a commit that referenced this pull request May 28, 2024
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D57849772

cccclai added a commit that referenced this pull request May 29, 2024
cccclai marked this pull request as ready for review May 29, 2024 05:41
cccclai requested a review from jerryzh168 May 29, 2024 05:42
@shewu-quic (Collaborator)

Hi Chen,
Thanks for your amazing work.
I have a few questions.

  1. This PR seems to assume use_kv_cache is disabled. May I know whether it works in use_kv_cache mode?
  2. Why does the eager_eval_wrapper only run inference once? In kv cache mode I think we expect to use the past kv cache to predict the next word, generating tokens one by one.
  3. I am trying hqq with our static llama, but calibration and evaluation are very slow. The bottleneck seems to be that we predict tokens one by one rather than as a whole sequence.

Really appreciate your sharing!

@cccclai (Contributor, Author) commented Jun 3, 2024

  1. This PR seems to assume use_kv_cache is disabled. May I know whether it works in use_kv_cache mode?

It should work with kv cache mode. It was just too slow when I started with hqq, and I was trying to get some initial signal on how the algorithm performs. PR #3732 helped address perf for the kv cache version a bit.

  2. Why does the eager_eval_wrapper only run inference once? In kv cache mode I think we expect to use the past kv cache to predict the next word, generating tokens one by one.

Oh, I think this line only predicts the next token given the prompt, and it's equivalent to the version without kv cache.

  3. I am trying hqq with our static llama, but calibration and evaluation are very slow. The bottleneck seems to be that we predict tokens one by one rather than as a whole sequence.

How long does it take for you to apply this algorithm to stories? Just so we can use it for reference. When I apply this quant mode to llama2, if it runs on CPU, it takes two hours to finish. If it's on GPU, it takes a few minutes. Also, if calibration takes too long, reduce the number of samples to 5. I set it to 30, but 5 can be quite reasonable.

Also, just a reminder: we need to set group_size=None explicitly in hqq. I'll add spin quant later.
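For reference, the group_size=None setting mentioned above would look roughly like this with hqq's config object (a sketch; the exact defaults of the other BaseQuantizeConfig fields are an assumption):

```python
from hqq.core.quantize import BaseQuantizeConfig

# 4-bit weights with no grouping: pass group_size=None explicitly
# instead of relying on the default group size.
quant_config = BaseQuantizeConfig(nbits=4, group_size=None)
```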

@shewu-quic (Collaborator)

Oh, I think this line only predicts the next token given the prompt, and it's equivalent to the version without kv cache.

Yes, from my observation too.

How long does it take for you to apply this algorithm to stories? Just so we can use it for reference. When I apply this quant mode to llama2, if it runs on CPU, it takes two hours to finish.

Ahh, I observed that when I switched to the version without kv cache, running eval went from taking 20 minutes to 30 seconds. Amazing!

If it's on GPU, it takes a few minutes. Also, if calibration takes too long, reduce the number of samples to 5. I set it to 30, but 5 can be quite reasonable.

Unfortunately, when I run it on GPU, it OOMs, because my GPU, an RTX 3080, only has 10 GB of VRAM.

Also, just a reminder: we need to set group_size=None explicitly in hqq. I'll add spin quant later.

Thanks for sharing. I will give it a shot on my model.

@shewu-quic (Collaborator) commented Jun 4, 2024

Update on my experiment.

I set group_size to None, and the PPL really goes up.

Our llama with hqq 16a4w:
wikitext: {'word_perplexity,none': 66.4990777220096, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 2.2699800048977496, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 1.182679589603456, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

Baseline:
wikitext: {'word_perplexity,none': 15.705432743397127, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.7124018930827969, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.7760213355666792, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

@cccclai (Contributor, Author) commented Jun 4, 2024

Update on my experiment.

I set group_size to None, and the PPL really goes up.

Our llama with hqq 16a4w: wikitext: {'word_perplexity,none': 66.4990777220096, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 2.2699800048977496, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 1.182679589603456, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

Baseline: wikitext: {'word_perplexity,none': 15.705432743397127, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.7124018930827969, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.7760213355666792, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

Thanks! Yeah, that's aligned with the observation on my side. How long does it take to run? Also, I expect this model can generate reasonable responses, though it will likely tend to stutter.

@shewu-quic (Collaborator) commented Jun 5, 2024

Thanks! Yeah, that's aligned with the observation on my side. How long does it take to run? Also, I expect this model can generate reasonable responses, though it will likely tend to stutter.

When I updated to this version, it took about 2 hours on CPU.

Regarding the issue with the mutable buffer: when we use your version, a lot of partitions are generated because QNN does not support the slice_scatter op. In our version, there are no partitions. The difference is that we update the KV cache outside of llama, while you update the KV cache in each attention layer. We are thinking about whether we can use some transforms from your version to update the KV cache at the end, so that we can do quantization research on the same llama. The exported graph looks like this:

  %executorch_call_delegate_11 : [num_users=7] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_11, %aten_index_tensor, %aten_index_tensor_1, %getitem_73, %aten_slice_scatter_default_42, %aten_slice_scatter_default_43, %getitem_69, %getitem_70, %aten_index_tensor_12, %getitem_72), kwargs = {})
  %getitem_76 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_11, 0), kwargs = {})
  %getitem_77 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_11, 1), kwargs = {})
  %getitem_78 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_11, 2), kwargs = {})
  %getitem_79 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_11, 3), kwargs = {})
  %getitem_80 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_11, 4), kwargs = {})
  %getitem_81 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_11, 5), kwargs = {})
  %getitem_82 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_11, 6), kwargs = {})
  %aten_index_tensor_13 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.index.Tensor](args = (%getitem_78, [None, None, %input_pos]), kwargs = {})
  %aten_index_put_default_22 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.index_put.default](args = (%b_layers_11_attention_sdpa_kv_cache_v_cache, [None, None, %input_pos], %getitem_81), kwargs = {})
  %aten_index_put_default_23 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.index_put.default](args = (%b_layers_11_attention_sdpa_kv_cache_k_cache, [None, None, %input_pos], %getitem_82), kwargs = {})
  %aten_slice_scatter_default_44 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.slice_scatter.default](args = (%b_layers_11_attention_sdpa_kv_cache_v_cache, %aten_index_put_default_22, 1, 0, 9223372036854775807), kwargs = {})
  %aten_slice_scatter_default_45 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.slice_scatter.default](args = (%b_layers_11_attention_sdpa_kv_cache_k_cache, %aten_index_put_default_23, 1, 0, 9223372036854775807), kwargs = {})
  %aten_slice_scatter_default_46 : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.slice_scatter.default](args = (%b_layers_11_attention_sdpa_kv_cache_v_cache, %aten_slice_scatter_default_44, 0, 0, 9223372036854775807), kwargs = {})
  %aten_slice_scatter_default_47 : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.slice_scatter.default](args = (%b_layers_11_attention_sdpa_kv_cache_k_cache, %aten_slice_scatter_default_45, 0, 0, 9223372036854775807), kwargs = {})
  %lowered_module_12 : [num_users=1] = get_attr[target=lowered_module_12]
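To make the two cache-update styles above concrete, here is a minimal sketch (hypothetical helper names, not the actual ExecuTorch or QNN code): writing into the cache in place inside each attention layer is what lowers to index_put/slice_scatter on a mutable buffer, like the nodes in the graph dump, while keeping the cache outside the model and concatenating avoids those scatter ops in the exported graph.

```python
import torch

# Style A: mutate the cache inside the attention layer (per-layer mutable buffer).
# When exported, this shows up as index_put / slice_scatter nodes like the ones above.
def update_kv_inplace(k_cache, v_cache, k_new, v_new, input_pos):
    k_cache[:, :, input_pos] = k_new  # in-place write into the buffer
    v_cache[:, :, input_pos] = v_new
    return k_cache, v_cache


# Style B: keep the cache outside the model and grow it by concatenation along the
# sequence dimension, so the exported graph contains no in-place scatter ops.
def update_kv_concat(k_cache, v_cache, k_new, v_new):
    k_cache = torch.cat([k_cache, k_new], dim=2)
    v_cache = torch.cat([v_cache, v_new], dim=2)
    return k_cache, v_cache
```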

@cccclai (Contributor, Author) commented Jun 5, 2024

Regarding the issue with the mutable buffer: when we use your version, a lot of partitions are generated because QNN does not support the slice_scatter op. In our version, there are no partitions. The difference is that we update the KV cache outside of llama, while you update the KV cache in each attention layer. We are thinking about whether we can use some transforms from your version to update the KV cache at the end.

Actually, can you try #3786? It will get rid of slice_scatter.

@shewu-quic (Collaborator) commented Jun 6, 2024

Actually, can you try #3786? It will get rid of slice_scatter.

Great! I have tried it, but we still get {num_layers} partitions due to the index_put op. I need to figure out how to support it in QNN.

But I could get one partition when I use concat to update the kv cache.
I will bring up another PR later.

BTW, with our llama with hqq 16a4w, I get the following result:

[INFO 2024-06-06 00:35:42,077 eval_llama.py:199] Result with static llama model in eager mode: wikitext: {'word_perplexity,none': 66.4990777220096, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 2.2699800048977496, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 1.182679589603456, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
Prompt:
Once
Result:
Once, I was in my 19 years old, and I was working in a big corporation. nobody is better than you.
I, 19 years old woman, and I was working in a big corporation.
I, 20 years old woman, and I was working in a big corporation.

@cccclai (Contributor, Author) commented Jun 6, 2024

Actually, can you try #3786? It will get rid of slice_scatter.

Great! I have tried it, but we still get {num_layers} partitions due to the index_put op. I need to figure out how to support it in QNN.

But I could get one partition when I use concat to update the kv cache. I will bring up another PR later.

BTW, with our llama with hqq 16a4w, I get the following result:

[INFO 2024-06-06 00:35:42,077 eval_llama.py:199] Result with static llama model in eager mode: wikitext: {'word_perplexity,none': 66.4990777220096, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 2.2699800048977496, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 1.182679589603456, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
Prompt:
Once
Result:
Once, I was in my 19 years old, and I was working in a big corporation. nobody is better than you.
I, 19 years old woman, and I was working in a big corporation.
I, 20 years old woman, and I was working in a big corporation.

Yeah, a 60+ perplexity number is not great... I expect it to say some readable words, but the quality won't be great.

Using concat sounds fine, as long as we have one partition. Regarding index_put, I actually tried something like this, but it's still local. I think it's still good to have more operators supported.


```python
# op_index_put.py
# Note: the imports below were not in the original paste; they are filled in
# assuming the usual executorch QNN builder module layout.
from dataclasses import dataclass
from typing import cast, Dict

import numpy as np
import torch

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpIndexPut, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class IndexPutVisitor(NodeVisitor):
    target = ["aten.index_put.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        indices_node = node.args[1]
        print("indices_node: ", indices_node)
        indices_list = [
            self.get_tensor(idx, idx) for idx in indices_node if idx is not None
        ]

        # Unpack the tuple of index tensors.
        indices_unpacked = [torch.flatten(idx) for idx in indices_list]

        # Transpose and flatten.
        indices_transposed = [torch.transpose(idx, 0, -1) for idx in indices_unpacked]

        # Convert to a 1-D tensor.
        indices_qnn = torch.cat(indices_transposed).unsqueeze(0)
        print("indices_qnn: ", indices_qnn, " indices_qnn.shape: ", indices_qnn.shape)

        # Use the first non-None fx node among the indices as the source node
        # for the indices tensor wrapper.
        indice_node = None
        for candidate_index_node in indices_node:
            if candidate_index_node is not None and isinstance(
                candidate_index_node, torch.fx.Node
            ):
                indice_node = candidate_index_node
                break
        indices_tensor_wrapper = self.define_tensor(
            indice_node,
            indices_qnn,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        value_node = node.args[2]
        value_tensor = self.get_tensor(value_node, node)
        value_tensor_wrapper = self.define_tensor(
            value_node,
            value_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        index_put_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpIndexPut.op_name,
        )
        # Default to axis 0; aten.index_put.default always has (input, indices, values),
        # so this branch is effectively a placeholder carried over from another visitor.
        axis = 0
        if len(node.args) == 2:
            axis = cast(int, node.args[1])
        if axis < 0:
            axis += node.meta["val"].dim()

        index_put_op.AddInputTensors(
            [input_tensor_wrapper, indices_tensor_wrapper, value_tensor_wrapper]
        )
        index_put_op.AddOutputTensors([output_tensor_wrapper])
        index_put_op.AddScalarParam(
            OpIndexPut.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {"data": np.uint32(axis)},
        )

        return index_put_op


# In qnn_constants.py
@dataclass(init=False, frozen=True)
class OpIndexPut:
    op_name: str = "ScatterElements"
    param_axis: str = "axis"
    param_reduction: str = "reduction"
```

I expect to use ScatterElements, but haven't had the chance to get it fully working.

@cccclai (Contributor, Author) commented Jun 6, 2024

I can check what spin quant outputs given the prompt.

Labels: CLA Signed, fb-exported