Duplicate pt2e quant code from pytorch/pytorch to pytorch/ao #1916

Closed · wants to merge 6 commits

317 changes: 317 additions & 0 deletions test/quantization/pt2e_flow/test_duplicate_dq.py
@@ -0,0 +1,317 @@
# Owner(s): ["oncall: quantization"]
# ruff: noqa: F841
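# Duplicated (per PR #1916) from the corresponding pt2e test in pytorch/pytorch,
# with imports retargeted at torchao.quantization.pt2e_flow.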
import copy
import unittest
from typing import Any

import torch
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS

from torchao.quantization.pt2e_flow.observer import (
HistogramObserver,
MinMaxObserver,
PlaceholderObserver,
)
from torchao.quantization.pt2e_flow.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e_flow.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
Quantizer,
SharedQuantizationSpec,
)
from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer_utils import (
OP_TO_ANNOTATOR,
QuantizationConfig,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if TORCH_VERSION_AT_LEAST_2_5:
from torch.export import export_for_training


class TestHelperModules:
class Conv2dWithObsSharingOps(torch.nn.Module):
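# Straight chain: conv -> avgpool -> hardtanh -> view -> linear, so every
# intermediate tensor has exactly one consumer.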
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.hardtanh = torch.nn.Hardtanh()
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
x = self.conv(x)
x = self.adaptive_avg_pool2d(x)
x = self.hardtanh(x)
x = x.view(-1, 3)
x = self.linear(x)
return x

class Conv2dWithSharedDQ(torch.nn.Module):
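# conv1's output fans out to conv2, add, a view + linear branch and a mul,
# so its dequantize node starts out with multiple users.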
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3)
self.conv2 = torch.nn.Conv2d(3, 3, 1)
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
x = self.conv1(x)
z = x.view(-1, 3)
w = self.linear(z)

y = self.conv2(x)
add_output = x + y

extra_output = x * 2
return w, add_output, extra_output

class ModuleForDifferentQconfig(torch.nn.Module):
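# Same fan-out pattern as Conv2dWithSharedDQ, but with an adaptive_avg_pool2d
# branch that the test below annotates with a different qconfig.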
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3)
self.conv2 = torch.nn.Conv2d(3, 3, 1)
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

def forward(self, x):
x = self.conv1(x)
w = self.adaptive_avg_pool2d(x)

y = self.conv2(x)
add_output = x + y

extra_output = x + 2
return w, add_output, extra_output


_DEQUANTIZE_OPS = [
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
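# _test_duplicate_dq asserts that each of these dq ops feeding an annotated node
# has exactly one user, i.e. the DQ was duplicated per quantized consumer.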


@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+")
class TestDuplicateDQPass(QuantizationTestCase):
def _test_duplicate_dq(
self,
model,
example_inputs,
quantizer,
):
m_eager = model.eval()

# program capture
m = copy.deepcopy(m_eager)
m = export_for_training(
m,
example_inputs,
).module()

m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
m = convert_pt2e(m)

pt2_quant_output = m(*example_inputs)
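# Every dq node that feeds an annotated (quantized) node must have exactly one
# user after convert_pt2e; a dq with multiple users was not duplicated.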
for n in m.graph.nodes:
annotation = n.meta.get("quantization_annotation", None)
if annotation is not None:
for arg in n.args:
if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS:
self.assertEqual(len(arg.users.keys()), 1)

def test_no_need_for_duplicate_dq(self):
"""
Model under test
conv2d -> avgpool -> hardtanh -> linear
Check that quantization tags on conv2d, avgpool and linear are correctly set
"""

class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
OP_TO_ANNOTATOR["conv"](gm, quantization_config)
OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)
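# Every annotated op consumes a tensor with a single user, so no dq node
# needs to be duplicated in this model.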

def validate(self, model: torch.fx.GraphModule) -> None:
pass

example_inputs = (torch.randn(1, 3, 5, 7),)
self._test_duplicate_dq(
TestHelperModules.Conv2dWithObsSharingOps(),
example_inputs,
BackendAQuantizer(),
)

def test_simple_duplicate_dq(self):
"""
Model under test
conv2d -> conv2d -> add
| |
--------->
|
-----> view_copy --> linear
|
-----> mul
There should be three dq nodes because the output of the
first conv2d is fed to the next conv2d, to add, and to view_copy + linear.
All three consumers are quantized, so the shared DQ node is duplicated
and each of those three uses gets its own dq.
"""

class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
OP_TO_ANNOTATOR["conv"](gm, quantization_config)
OP_TO_ANNOTATOR["add"](gm, quantization_config)
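# conv, add and linear are all annotated, so conv1's output has three
# quantized consumers and its dq must be duplicated three ways.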

def validate(self, model: torch.fx.GraphModule) -> None:
pass

example_inputs = (torch.randn(1, 3, 5, 7),)
self._test_duplicate_dq(
TestHelperModules.Conv2dWithSharedDQ(),
example_inputs,
BackendAQuantizer(),
)

def test_no_add_quant_duplicate_dq(self):
"""
Model under test
conv2d -> conv2d -> add
| |
--------->
|
-----> view_copy --> linear
|
-----> mul
There should be three dq nodes: the output of the first conv2d is fed to the
next conv2d and to view_copy + linear, and both of those consumers are
quantized, so each gets its own dq. However, the skip connections to add and
mul are not quantized, so the DQ node is not duplicated for those two uses
and they share the remaining dq.
"""

class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
OP_TO_ANNOTATOR["linear"](gm, quantization_config)
OP_TO_ANNOTATOR["conv"](gm, quantization_config)
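# add and mul are deliberately left unannotated, so they may keep sharing
# a single dq node for conv1's output.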

def validate(self, model: torch.fx.GraphModule) -> None:
pass

example_inputs = (torch.randn(1, 3, 5, 7),)
self._test_duplicate_dq(
TestHelperModules.Conv2dWithSharedDQ(),
example_inputs,
BackendAQuantizer(),
)

def test_avgpool_use_different_qconfig(self):
"""
Model under test
conv2d -> conv2d -> add
| |
--------->
|
-----> adaptive_avgpool2d (different qconfig)
|
-----> add
output
conv2d -> dq -> conv2d -> add
| |
-------> dq ----->
|
-> dq -> q -> dq -----> adaptive_avgpool2d (different qconfig)
|
-> dq -----> add
"""

def _get_uint8_quantization_config():
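# Per-tensor affine uint8 spec for activations and weights with a float bias,
# intentionally different from the default symmetric XNNPACK config.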
act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment]
act_quantization_spec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
eps=2**-12
),
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( # noqa: F821
MinMaxObserver
)

extra_args: dict[str, Any] = {"eps": 2**-12}
weight_quantization_spec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
ch_axis=0,
is_dynamic=False,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
**extra_args
),
)

bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( # noqa: F821
PlaceholderObserver
)
bias_quantization_spec = QuantizationSpec(
dtype=torch.float,
observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr,
)
quantization_config = QuantizationConfig(
act_quantization_spec,
act_quantization_spec,
weight_quantization_spec,
bias_quantization_spec,
)
return quantization_config

class BackendAQuantizer(Quantizer):
def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
backend_string = "BackendA"
quantization_config = get_symmetric_quantization_config(
is_per_channel=True
)
avgpool_qconfig = _get_uint8_quantization_config()
OP_TO_ANNOTATOR["conv"](gm, quantization_config)
OP_TO_ANNOTATOR["add"](gm, quantization_config)
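# AdaptiveAvgPool2d((1, 1)) is traced as aten.mean.dim; annotate it by hand
# with the uint8 config and share its output qspec with its input edge.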
for n in gm.graph.nodes:
if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
qspec = avgpool_qconfig.input_activation
input_act = n.args[0]
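# The output reuses the quantization parameters observed on the
# (input, node) edge rather than getting its own observer.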
output_qspec = SharedQuantizationSpec((input_act, n))
n.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={input_act: qspec},
output_qspec=output_qspec,
_annotated=True,
)

def validate(self, model: torch.fx.GraphModule) -> None:
pass

example_inputs = (torch.randn(1, 3, 5, 7),)
self._test_duplicate_dq(
TestHelperModules.ModuleForDifferentQconfig(),
example_inputs,
BackendAQuantizer(),
)