
RuntimeError: Missing out variants: {'quantized_decomposed::dequantize_per_tensor', 'quantized_decomposed::quantize_per_tensor'} #8369


Closed
ChristophKarlHeck opened this issue Feb 11, 2025 · 10 comments
Labels: module: examples, module: xnnpack, triaged


@ChristophKarlHeck

ChristophKarlHeck commented Feb 11, 2025

🐛 Describe the bug

Hi,
The following code throws the error mentioned in the title.

```python
import torch
from torch.export import export, export_for_training, ExportedProgram
from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
import executorch.exir as exir


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


example_args = (torch.randn(3, 4),)
pre_autograd_aten_dialect = export_for_training(M(), example_args).module()
# Optionally do quantization:
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, quantizer))
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)
# Optionally do delegation:
# edge_program = edge_program.to_backend(CustomBackendPartitioner)
executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
    ExecutorchBackendConfig(
        passes=[],  # User-defined passes
    )
)

with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)
```

Am I doing anything wrong?
Cheers,
Christoph

Versions

PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.0
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-122-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.5.119
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 950M
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        39 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               8
On-line CPU(s) list:                  0-7
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz
CPU family:                           6
Model:                                94
Thread(s) per core:                   2
Core(s) per socket:                   4
Socket(s):                            1
Stepping:                             3
CPU max MHz:                          3500,0000
CPU min MHz:                          800,0000
BogoMIPS:                             5199.98
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp md_clear flush_l1d arch_capabilities
Virtualization:                       VT-x
L1d cache:                            128 KiB (4 instances)
L1i cache:                            128 KiB (4 instances)
L2 cache:                             1 MiB (4 instances)
L3 cache:                             6 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-7
Vulnerability Gather data sampling:   Vulnerable: No microcode
Vulnerability Itlb multihit:          KVM: Mitigation: VMX disabled
Vulnerability L1tf:                   Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds:                    Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:               Mitigation; PTI
Vulnerability Mmio stale data:        Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Mitigation; IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; IBRS; IBPB conditional; STIBP conditional; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Mitigation; Microcode
Vulnerability Tsx async abort:        Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] executorch==0.5.0
[pip3] numpy==2.0.0
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pytorch-lightning==2.5.0.post0
[pip3] torch==2.6.0
[pip3] torchaudio==2.6.0
[pip3] torchmetrics==1.6.1
[pip3] torchview==0.2.6
[pip3] torchvision==0.21.0
[pip3] triton==3.2.0
[conda] Could not collect

cc @digantdesai @mcr229
@jackzhxng
Contributor

Here are some related threads:

I think in your case you are quantizing with XNNPACK; maybe try forcing model.to(torch.float32) before your export.

@jackzhxng added the module: examples and triaged labels Feb 11, 2025
@jackzhxng jackzhxng self-assigned this Feb 11, 2025
@mergennachin added the module: xnnpack label Feb 11, 2025
@ChristophKarlHeck
Author

ChristophKarlHeck commented Feb 12, 2025

@jackzhxng
Thank you for the reply, but forcing model.to(torch.float32) doesn't work. Since I use the STM32WB55RG and there is no XNNPACK support for this kind of hardware, I need to find a way to quantize the model without XNNPACK. Unfortunately, the source-based quantizations throw the following error, too:

RuntimeError: Missing out variants: {'quantized_decomposed::dequantize_per_token', 'quantized_decomposed::choose_qparams_per_token_asymmetric', 'quantized_decomposed::dequantize_per_channel_group', 'quantized_decomposed::quantize_per_token'}

Here is the script if you want to reproduce it:

```python
import torch
from torch.export import export, export_for_training, ExportedProgram
from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
import torch.ao.quantization.quantizer as quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
import executorch.exir as exir

import executorch.exir.passes as passes

print(dir(passes))

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5, bias=False)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

example_args = (torch.randn(3, 4),)
model = M()

print("Before quantization:")
for name, param in model.named_parameters():
    print(f"{name}: dtype={param.dtype}, shape={param.shape}")

model = Int8DynActInt4WeightQuantizer(precision=torch.int8, groupsize=4).quantize(model)

print("\nAfter quantization:")
for name, param in model.named_parameters():
    print(f"{name}: dtype={param.dtype}, shape={param.shape}")

# Export the Quantized Model
pre_autograd_aten_dialect = export_for_training(model, example_args).module()

# Re-export to an ATen-dialect ExportedProgram and lower it to the Edge dialect
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)

# Lower to ExecuTorch; this is the call that raises the RuntimeError
executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
    ExecutorchBackendConfig(
        passes=[],  # No user-defined passes
    )
)

# Save Quantized Executorch Model
with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)
```

@jackzhxng
Contributor

Try adding import executorch.kernels.quantized?

@ChristophKarlHeck
Author

ChristophKarlHeck commented Feb 13, 2025

Doesn't work:

```
Before quantization:
param: dtype=torch.float32, shape=torch.Size([3, 4])
linear.weight: dtype=torch.float32, shape=torch.Size([5, 4])

After quantization:
param: dtype=torch.float32, shape=torch.Size([3, 4])
Traceback (most recent call last):
  File "/home/chris/watchplant_classification_dl/pipline_test/test_quan.py", line 45, in <module>
    executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
  File "/home/chris/watchplant_classification_dl/.venv/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 93, in wrapper
    return func(self, *args, **kwargs)
  File "/home/chris/watchplant_classification_dl/.venv/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 1364, in to_executorch
    new_gm_res = p(new_gm)
  File "/home/chris/watchplant_classification_dl/.venv/lib/python3.10/site-packages/torch/fx/passes/infra/pass_base.py", line 44, in __call__
    res = self.call(graph_module)
  File "/home/chris/watchplant_classification_dl/.venv/lib/python3.10/site-packages/executorch/exir/passes/__init__.py", line 427, in call
    raise RuntimeError(f"Missing out variants: {missing_out_vars}")
RuntimeError: Missing out variants: {'quantized_decomposed::quantize_per_token', 'quantized_decomposed::dequantize_per_channel_group', 'quantized_decomposed::dequantize_per_token', 'quantized_decomposed::choose_qparams_per_token_asymmetric'}
```
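For context, the four ops named in this error appear to come from the two schemes `Int8DynActInt4WeightQuantizer` applies, as we understand torchao's implementation: int8 activations quantized dynamically per token (per row) with an asymmetric scale/zero-point, and int4 weights quantized symmetrically with one scale per group of `groupsize` values. The pure-Python sketch below only illustrates that arithmetic. The function names mirror the op names but are hypothetical, and torchao's exact scale formula, rounding, and clamping may differ (here the int4 scale is assumed to be `max|w| / 7`).

```python
from typing import List, Tuple

INT8_MIN, INT8_MAX = -128, 127
INT4_MIN, INT4_MAX = -8, 7


def choose_qparams_per_token_asymmetric(row: List[float]) -> Tuple[float, int]:
    """Pick an int8 scale/zero-point for one token (one activation row)."""
    lo, hi = min(min(row), 0.0), max(max(row), 0.0)  # keep 0.0 exactly representable
    scale = (hi - lo) / (INT8_MAX - INT8_MIN) or 1.0  # fall back for an all-zero row
    zero_point = INT8_MIN - round(lo / scale)
    return scale, max(INT8_MIN, min(INT8_MAX, zero_point))


def quantize_per_token(row: List[float], scale: float, zero_point: int) -> List[int]:
    return [max(INT8_MIN, min(INT8_MAX, round(x / scale) + zero_point)) for x in row]


def dequantize_per_token(qrow: List[int], scale: float, zero_point: int) -> List[float]:
    return [(q - zero_point) * scale for q in qrow]


def quantize_per_channel_group_int4(row: List[float], groupsize: int) -> Tuple[List[int], List[float]]:
    """Symmetric int4 weights: one scale per `groupsize` consecutive values (assumed formula)."""
    qrow: List[int] = []
    scales: List[float] = []
    for start in range(0, len(row), groupsize):
        group = row[start:start + groupsize]
        scale = max(abs(w) for w in group) / INT4_MAX or 1.0
        scales.append(scale)
        qrow.extend(max(INT4_MIN, min(INT4_MAX, round(w / scale))) for w in group)
    return qrow, scales


def dequantize_per_channel_group_int4(qrow: List[int], scales: List[float], groupsize: int) -> List[float]:
    return [q * scales[i // groupsize] for i, q in enumerate(qrow)]


# Usage: one activation row and one weight row (values are made up).
x = [-1.0, 0.0, 0.5, 1.0]
s, zp = choose_qparams_per_token_asymmetric(x)
x_hat = dequantize_per_token(quantize_per_token(x, s, zp), s, zp)

w = [0.7, -0.35, 0.0, 0.14, 1.4, -1.4, 0.2, 0.0]
qw, w_scales = quantize_per_channel_group_int4(w, groupsize=4)
w_hat = dequantize_per_channel_group_int4(qw, w_scales, groupsize=4)
```

Each of these steps is a separate node in the exported graph, which is why the runtime needs an out-variant kernel for every one of them.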

@mcr229
Contributor

mcr229 commented Feb 13, 2025

Hi @ChristophKarlHeck, it looks like you're attempting to lower to your own custom backend. I think for now ExecuTorch might not have the implementations for those quantized variants.

Generally, these ops should be consumed by your backend in order to run the model:

[Image: the quantized linear pattern]

As in the example above, to run the quantized linear you match against the quantized pattern shown. For now, the only way these ops are run is through XNNPACK delegation, which recognizes this pattern as a quantized linear.
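A toy sketch of what "matching the quantized pattern" means: the backend walks the graph, finds a linear op whose input and weight both come from dequantize nodes and whose only consumer is a quantize node, and can then replace those four nodes with a single integer linear kernel. The `Node` class, the `add_node` helper, and the `aten::linear` op name are hypothetical simplifications, not the `torch.fx` API or the XNNPACK partitioner's actual code; real graphs use different op names and need more checks (dtypes, qparams, shapes).

```python
from dataclasses import dataclass, field
from typing import List, Optional

DQ = "quantized_decomposed::dequantize_per_tensor"
Q = "quantized_decomposed::quantize_per_tensor"


@dataclass(eq=False)
class Node:
    """Hypothetical, simplified stand-in for a graph node (not torch.fx.Node)."""
    name: str
    op: str
    args: List["Node"] = field(default_factory=list)
    users: List["Node"] = field(default_factory=list)


def add_node(graph: List[Node], name: str, op: str, *args: Node) -> Node:
    """Append a node and record it as a user of each of its inputs."""
    node = Node(name, op, list(args))
    for arg in args:
        arg.users.append(node)
    graph.append(node)
    return node


def match_quantized_linear(node: Node) -> Optional[List[Node]]:
    """Return [dq_input, dq_weight, linear, q_output] if `node` is the center of the pattern."""
    if node.op != "aten::linear" or len(node.args) < 2:
        return None
    dq_input, dq_weight = node.args[0], node.args[1]
    if dq_input.op != DQ or dq_weight.op != DQ:
        return None
    if len(node.users) != 1 or node.users[0].op != Q:
        return None
    return [dq_input, dq_weight, node, node.users[0]]


# Build the q -> dq -> linear -> q -> dq chain a PT2E-quantized linear produces.
graph: List[Node] = []
x = add_node(graph, "x", "placeholder")
w = add_node(graph, "w", "placeholder")  # weight assumed to be stored already quantized
q_x = add_node(graph, "q_x", Q, x)
dq_x = add_node(graph, "dq_x", DQ, q_x)
dq_w = add_node(graph, "dq_w", DQ, w)
lin = add_node(graph, "linear", "aten::linear", dq_x, dq_w)
q_out = add_node(graph, "q_out", Q, lin)
dq_out = add_node(graph, "dq_out", DQ, q_out)

matches = [m for n in graph if (m := match_quantized_linear(n))]
print([[n.name for n in m] for m in matches])  # [['dq_x', 'dq_w', 'linear', 'q_out']]
```

Without a backend doing this kind of rewrite, the q/dq nodes stay in the graph as standalone ops, and the runtime then needs kernels for them.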

@ChristophKarlHeck
Copy link
Author

Hi @mcr229,
Thank you for the explanation. My goal is to minimize the size of model_pte.h. From your explanation, I understand that we need to be able to run XNNPACK on our target hardware, which we can't, so post-training quantization doesn't work. I hope that we can do pre-training quantization instead, i.e., int8 only instead of float32. Do you see any other approach to achieve this without using a backend on the target hardware, using just the non-quantized ATen operators?

I appreciate any help you can provide.
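As a rough back-of-the-envelope check on what int8 storage could save, here are the parameter counts of module `M` from the first script in this thread (a 3x4 `param` plus `Linear(4, 5)` with bias). This counts raw parameter bytes only. It ignores quantization scales and zero-points as well as the rest of the .pte (program structure, operator metadata), which may well dominate for a model this small.

```python
from math import prod

# Parameter shapes of module M from the first script in this thread.
PARAM_SHAPES = {
    "param": (3, 4),
    "linear.weight": (5, 4),  # nn.Linear stores weight as (out_features, in_features)
    "linear.bias": (5,),
}


def param_bytes(shapes, bytes_per_element):
    """Raw storage for all parameters at a given element size (no qparams, no metadata)."""
    return sum(prod(shape) for shape in shapes.values()) * bytes_per_element


num_params = sum(prod(shape) for shape in PARAM_SHAPES.values())
fp32 = param_bytes(PARAM_SHAPES, 4)  # torch.float32
int8 = param_bytes(PARAM_SHAPES, 1)  # torch.int8
print(num_params, fp32, int8)  # 37 148 37
```

The weights shrink by 4x in the ideal case, so int8 mainly pays off once the parameters, not the program metadata, make up most of the file.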

@mcr229
Copy link
Contributor

mcr229 commented Feb 14, 2025

For the STM32WB55RG: we don't have Arm Cortex-M4 support just yet. It should only be supported through portable ops, but we don't have quantized portable ops for the reasons above.

@digantdesai actually has plans for Cortex-M4 support, so perhaps he can shed some light on workarounds.

@digantdesai
Copy link
Contributor

digantdesai commented Feb 14, 2025

  • Yes, no XNNPACK on Cortex-M
  • The portable library should work on Cortex-M; you have to use dtype selective build to make sure the binary size doesn't blow up. Think of it as a reference library, so performance can be poor. We are working on adding CMSIS-{NN, DSP}-like high-performance operators targeting Cortex-M soon (ETA: maybe a year?).
  • For quant ops:
    • If you really want to run them, these implementations should compile on M4, I think. I have to double-check, but I would try them first.
    • If your input/output at runtime can be int8 (which I think is what you want, i.e., an end-to-end int8 network), then you can use the quantize-io-pass to get rid of the first q node and the last dq node.
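The q and dq nodes mentioned in the last point are simple affine maps. A pure-Python sketch of the per-tensor formulas, assuming int8 with `qmin=-128` and `qmax=127` (the usual textbook definition, not ExecuTorch's kernel source), shows what the host takes over once the first q node and last dq node are removed: it quantizes its float input with the input scale/zero-point, feeds int8 to the model, and dequantizes the int8 output itself (or consumes it as int8 directly). The scale and zero-point values are made up for illustration.

```python
from typing import List


def quantize_per_tensor(x: List[float], scale: float, zero_point: int,
                        qmin: int = -128, qmax: int = 127) -> List[int]:
    """q = clamp(round(x / scale) + zero_point, qmin, qmax)"""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in x]


def dequantize_per_tensor(q: List[int], scale: float, zero_point: int) -> List[float]:
    """x_hat = (q - zero_point) * scale"""
    return [(v - zero_point) * scale for v in q]


# Hypothetical input qparams; in a real program they come from the removed first q node.
IN_SCALE, IN_ZERO_POINT = 0.05, 0

sensor_reading = [0.0, 0.1, -0.25, 1.0]
q_in = quantize_per_tensor(sensor_reading, IN_SCALE, IN_ZERO_POINT)
print(q_in)  # [0, 2, -5, 20]
print(quantize_per_tensor([100.0], IN_SCALE, IN_ZERO_POINT))  # saturates: [127]
restored = dequantize_per_tensor(q_in, IN_SCALE, IN_ZERO_POINT)
```

Because these two conversions are this cheap, moving them to the caller removes the need for the model itself to carry them at its boundaries; any q/dq nodes inside the network still need kernels.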

@ChristophKarlHeck
Copy link
Author

@digantdesai Thank you!
I will try it if I can't get by with small unquantized CNNs running on the M4.

@digantdesai
Copy link
Contributor

Closing this, feel free to reopen if you run into something similar. Good luck!
