From f0dd295b6d7a31471a412dd8cfd974fa4676a246 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 17 Mar 2025 17:08:11 -0700 Subject: [PATCH 1/6] Migrate pt2e quant code from pytorch/pytorch to pytorch/ao Summary: First step of https://dev-discuss.pytorch.org/t/torch-ao-quantization-migration-plan/2810 core logic of pt2e are duplicated, and also ported test here Next: move meta internal callsites to depend on torchao Test Plan: pytest test/quantization/pt2e_flow Reviewers: Subscribers: Tasks: Tags: --- .../pt2e_flow/test_duplicate_dq.py | 313 ++ .../pt2e_flow/test_graph_utils.py | 124 + .../pt2e_flow/test_metadata_porting.py | 521 ++++ .../pt2e_flow/test_numeric_debugger.py | 350 +++ .../pt2e_flow/test_quantize_pt2e.py | 2598 ++++++++++++++++ .../pt2e_flow/test_quantize_pt2e_qat.py | 1161 +++++++ .../pt2e_flow/test_representation.py | 314 ++ .../pt2e_flow/test_x86inductor_quantizer.py | 2737 +++++++++++++++++ .../pt2e_flow/test_xnnpack_quantizer.py | 1092 +++++++ torchao/quantization/pt2e_flow/__init__.py | 166 + .../quantization/pt2e_flow/fake_quantize.py | 650 ++++ torchao/quantization/pt2e_flow/observer.py | 2050 ++++++++++++ .../pt2e_flow/pt2e/_affine_quantization.py | 775 +++++ .../pt2e_flow/pt2e/_numeric_debugger.py | 342 ++ .../quantization/pt2e_flow/pt2e/convert.py | 1400 +++++++++ .../pt2e_flow/pt2e/duplicate_dq_pass.py | 82 + .../pt2e_flow/pt2e/export_utils.py | 240 ++ .../pt2e_flow/pt2e/graph_utils.py | 180 ++ .../pt2e_flow/pt2e/port_metadata_pass.py | 215 ++ .../quantization/pt2e_flow/pt2e/prepare.py | 662 ++++ .../quantization/pt2e_flow/pt2e/qat_utils.py | 991 ++++++ .../pt2e_flow/pt2e/representation/__init__.py | 5 + .../pt2e_flow/pt2e/representation/rewrite.py | 819 +++++ torchao/quantization/pt2e_flow/pt2e/utils.py | 610 ++++ torchao/quantization/pt2e_flow/qconfig.py | 699 +++++ .../quantization/pt2e_flow/quantize_pt2e.py | 266 ++ .../pt2e_flow/quantizer/__init__.py | 21 + .../quantizer/composable_quantizer.py | 78 + .../quantizer/embedding_quantizer.py | 97 + .../pt2e_flow/quantizer/quantizer.py | 180 ++ .../quantization/pt2e_flow/quantizer/utils.py | 83 + .../quantizer/x86_inductor_quantizer.py | 1572 ++++++++++ .../pt2e_flow/quantizer/xnnpack_quantizer.py | 447 +++ .../quantizer/xnnpack_quantizer_utils.py | 1127 +++++++ .../quantizer/xpu_inductor_quantizer.py | 125 + torchao/quantization/pt2e_flow/utils.py | 822 +++++ torchao/quantization/quant_primitives.py | 2 +- torchao/testing/pt2e/utils.py | 158 + 38 files changed, 24073 insertions(+), 1 deletion(-) create mode 100644 test/quantization/pt2e_flow/test_duplicate_dq.py create mode 100644 test/quantization/pt2e_flow/test_graph_utils.py create mode 100644 test/quantization/pt2e_flow/test_metadata_porting.py create mode 100644 test/quantization/pt2e_flow/test_numeric_debugger.py create mode 100644 test/quantization/pt2e_flow/test_quantize_pt2e.py create mode 100644 test/quantization/pt2e_flow/test_quantize_pt2e_qat.py create mode 100644 test/quantization/pt2e_flow/test_representation.py create mode 100644 test/quantization/pt2e_flow/test_x86inductor_quantizer.py create mode 100644 test/quantization/pt2e_flow/test_xnnpack_quantizer.py create mode 100644 torchao/quantization/pt2e_flow/__init__.py create mode 100644 torchao/quantization/pt2e_flow/fake_quantize.py create mode 100644 torchao/quantization/pt2e_flow/observer.py create mode 100644 torchao/quantization/pt2e_flow/pt2e/_affine_quantization.py create mode 100644 torchao/quantization/pt2e_flow/pt2e/_numeric_debugger.py create mode 100644 
torchao/quantization/pt2e_flow/pt2e/convert.py create mode 100644 torchao/quantization/pt2e_flow/pt2e/duplicate_dq_pass.py create mode 100644 torchao/quantization/pt2e_flow/pt2e/export_utils.py create mode 100644 torchao/quantization/pt2e_flow/pt2e/graph_utils.py create mode 100644 torchao/quantization/pt2e_flow/pt2e/port_metadata_pass.py create mode 100644 torchao/quantization/pt2e_flow/pt2e/prepare.py create mode 100644 torchao/quantization/pt2e_flow/pt2e/qat_utils.py create mode 100644 torchao/quantization/pt2e_flow/pt2e/representation/__init__.py create mode 100644 torchao/quantization/pt2e_flow/pt2e/representation/rewrite.py create mode 100644 torchao/quantization/pt2e_flow/pt2e/utils.py create mode 100644 torchao/quantization/pt2e_flow/qconfig.py create mode 100644 torchao/quantization/pt2e_flow/quantize_pt2e.py create mode 100644 torchao/quantization/pt2e_flow/quantizer/__init__.py create mode 100644 torchao/quantization/pt2e_flow/quantizer/composable_quantizer.py create mode 100644 torchao/quantization/pt2e_flow/quantizer/embedding_quantizer.py create mode 100644 torchao/quantization/pt2e_flow/quantizer/quantizer.py create mode 100644 torchao/quantization/pt2e_flow/quantizer/utils.py create mode 100644 torchao/quantization/pt2e_flow/quantizer/x86_inductor_quantizer.py create mode 100644 torchao/quantization/pt2e_flow/quantizer/xnnpack_quantizer.py create mode 100644 torchao/quantization/pt2e_flow/quantizer/xnnpack_quantizer_utils.py create mode 100644 torchao/quantization/pt2e_flow/quantizer/xpu_inductor_quantizer.py create mode 100644 torchao/quantization/pt2e_flow/utils.py create mode 100644 torchao/testing/pt2e/utils.py diff --git a/test/quantization/pt2e_flow/test_duplicate_dq.py b/test/quantization/pt2e_flow/test_duplicate_dq.py new file mode 100644 index 0000000000..faec98e589 --- /dev/null +++ b/test/quantization/pt2e_flow/test_duplicate_dq.py @@ -0,0 +1,313 @@ +# Owner(s): ["oncall: quantization"] +# ruff: noqa: F841 +import copy +import unittest +from typing import Any + +import torch +from torch.export import export_for_training +from torch.testing._internal.common_quantization import QuantizationTestCase +from torch.testing._internal.common_utils import IS_WINDOWS + +from torchao.quantization.pt2e_flow.observer import ( + HistogramObserver, + MinMaxObserver, + PlaceholderObserver, +) +from torchao.quantization.pt2e_flow.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e_flow.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, + SharedQuantizationSpec, +) +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, +) +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer_utils import ( + OP_TO_ANNOTATOR, + QuantizationConfig, +) + + +class TestHelperModules: + class Conv2dWithObsSharingOps(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.hardtanh = torch.nn.Hardtanh() + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + x = self.conv(x) + x = self.adaptive_avg_pool2d(x) + x = self.hardtanh(x) + x = x.view(-1, 3) + x = self.linear(x) + return x + + class Conv2dWithSharedDQ(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 1) + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + x = self.conv1(x) + z = x.view(-1, 
3) + w = self.linear(z) + + y = self.conv2(x) + add_output = x + y + + extra_output = x * 2 + return w, add_output, extra_output + + class ModuleForDifferentQconfig(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 1) + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + x = self.conv1(x) + w = self.adaptive_avg_pool2d(x) + + y = self.conv2(x) + add_output = x + y + + extra_output = x + 2 + return w, add_output, extra_output + + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") +class TestDuplicateDQPass(QuantizationTestCase): + def _test_duplicate_dq( + self, + model, + example_inputs, + quantizer, + ): + m_eager = model.eval() + + # program capture + m = copy.deepcopy(m_eager) + m = export_for_training( + m, + example_inputs, + ).module() + + m = prepare_pt2e(m, quantizer) + # Calibrate + m(*example_inputs) + m = convert_pt2e(m) + + pt2_quant_output = m(*example_inputs) + for n in m.graph.nodes: + annotation = n.meta.get("quantization_annotation", None) + if annotation is not None: + for arg in n.args: + if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS: + self.assertEqual(len(arg.users.keys()), 1) + + def test_no_need_for_duplicate_dq(self): + """ + Model under test + conv2d -> avgpool -> hardtanh -> linear + Check quantization tags on conv2d, avgpool and linear are correctly set + """ + + class BackendAQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + backend_string = "BackendA" + quantization_config = get_symmetric_quantization_config( + is_per_channel=True + ) + OP_TO_ANNOTATOR["linear"](gm, quantization_config) + OP_TO_ANNOTATOR["conv"](gm, quantization_config) + OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 7),) + self._test_duplicate_dq( + TestHelperModules.Conv2dWithObsSharingOps(), + example_inputs, + BackendAQuantizer(), + ) + + def test_simple_duplicate_dq(self): + """ + Model under test + conv2d -> conv2d -> add + | | + ---------> + | + -----> view_copy --> linear + | + -----> mul + There should be three dq nodes because output for the + first conv2d is fed to next conv2d, add, and view_copy + linear. + All three are quantized. 
+ Thus DQ node is not duplicated for those three uses + """ + + class BackendAQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + backend_string = "BackendA" + quantization_config = get_symmetric_quantization_config( + is_per_channel=True + ) + OP_TO_ANNOTATOR["linear"](gm, quantization_config) + OP_TO_ANNOTATOR["conv"](gm, quantization_config) + OP_TO_ANNOTATOR["add"](gm, quantization_config) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 7),) + self._test_duplicate_dq( + TestHelperModules.Conv2dWithSharedDQ(), + example_inputs, + BackendAQuantizer(), + ) + + def test_no_add_quant_duplicate_dq(self): + """ + Model under test + conv2d -> conv2d -> add + | | + ---------> + | + -----> view_copy --> linear + | + -----> mul + There should be three dq nodes because output for the + first conv2d is fed to next conv2d, and view_copy + linear. + Both are quantized. + However the skip connection to add and mul are not quantized. + Thus DQ node is not duplicated for those two uses + """ + + class BackendAQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + backend_string = "BackendA" + quantization_config = get_symmetric_quantization_config( + is_per_channel=True + ) + OP_TO_ANNOTATOR["linear"](gm, quantization_config) + OP_TO_ANNOTATOR["conv"](gm, quantization_config) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 7),) + self._test_duplicate_dq( + TestHelperModules.Conv2dWithSharedDQ(), + example_inputs, + BackendAQuantizer(), + ) + + def test_avgpool_use_different_qconfig(self): + """ + Model under test + conv2d -> conv2d -> add + | | + ---------> + | + -----> adaptive_avgpool2d (different qconfig) + | + -----> add + output + conv2d -> dq -> conv2d -> add + | | + -------> dq -----> + | + -> dq -> q -> dq -----> adaptive_avgpool2d (different qconfig) + | + -> dq -----> add + """ + + def _get_uint8_quantization_config(): + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + eps=2**-12 + ), + ) + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( # noqa: F821 + MinMaxObserver + ) + + extra_args: dict[str, Any] = {"eps": 2**-12} + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( # noqa: F821 + PlaceholderObserver + ) + bias_quantization_spec = QuantizationSpec( + dtype=torch.float, + observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr, + ) + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + ) + return quantization_config + + class BackendAQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + backend_string = "BackendA" + quantization_config = get_symmetric_quantization_config( + is_per_channel=True + ) + avgpool_qconfig = _get_uint8_quantization_config() + OP_TO_ANNOTATOR["conv"](gm, quantization_config) + 
OP_TO_ANNOTATOR["add"](gm, quantization_config) + for n in gm.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.mean.dim: + qspec = avgpool_qconfig.input_activation + input_act = n.args[0] + output_qspec = SharedQuantizationSpec((input_act, n)) + n.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={input_act: qspec}, + output_qspec=output_qspec, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 7),) + self._test_duplicate_dq( + TestHelperModules.ModuleForDifferentQconfig(), + example_inputs, + BackendAQuantizer(), + ) diff --git a/test/quantization/pt2e_flow/test_graph_utils.py b/test/quantization/pt2e_flow/test_graph_utils.py new file mode 100644 index 0000000000..42ac3f244f --- /dev/null +++ b/test/quantization/pt2e_flow/test_graph_utils.py @@ -0,0 +1,124 @@ +# Owner(s): ["oncall: quantization"] +import copy +import unittest + +import torch +import torch._dynamo as torchdynamo +from torch.testing._internal.common_utils import IS_WINDOWS, TestCase + +from torchao.quantization.pt2e_flow.pt2e.graph_utils import ( + find_sequential_partitions, + get_equivalent_types, + update_equivalent_types_dict, +) + + +class TestGraphUtils(TestCase): + @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") + def test_conv_bn_conv_relu(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.bn1 = torch.nn.BatchNorm2d(3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + self.relu2 = torch.nn.ReLU() + + def forward(self, x): + bn_out = self.bn1(self.conv1(x)) + relu_out = torch.nn.functional.relu(bn_out) + return self.relu2(self.conv2(relu_out)) + + m = M().eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + + # program capture + m, guards = torchdynamo.export( # noqa: F841© + m, + *copy.deepcopy(example_inputs), + aten_graph=True, + ) + fused_partitions = find_sequential_partitions( + m, [torch.nn.Conv2d, torch.nn.BatchNorm2d] + ) + self.assertEqual(len(fused_partitions), 1) + fused_partitions = find_sequential_partitions( + m, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU] + ) + self.assertEqual(len(fused_partitions), 1) + + def x(): + find_sequential_partitions( + m, + [ + torch.nn.Conv2d, + torch.nn.BatchNorm2d, + torch.nn.ReLU, + torch.nn.functional.conv2d, + ], + ) + + self.assertRaises(ValueError, x) + + @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") + def test_conv_bn_relu(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.bn1 = torch.nn.BatchNorm2d(3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + self.relu2 = torch.nn.ReLU() + + def forward(self, x): + bn_out = self.bn1(x) + return self.relu2(self.conv2(bn_out)) + + m = M().eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + + # program capture + m, guards = torchdynamo.export( # noqa: F841 + m, + *copy.deepcopy(example_inputs), + aten_graph=True, + ) + fused_partitions = find_sequential_partitions( + m, [torch.nn.Conv2d, torch.nn.BatchNorm2d] + ) + self.assertEqual(len(fused_partitions), 0) + fused_partitions = find_sequential_partitions( + m, [torch.nn.BatchNorm2d, torch.nn.Conv2d] + ) + self.assertEqual(len(fused_partitions), 1) + fused_partitions = find_sequential_partitions( + m, [torch.nn.BatchNorm2d, torch.nn.ReLU] + ) + self.assertEqual(len(fused_partitions), 0) + + @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") + def 
test_customized_equivalet_types_dict(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + return torch.nn.functional.relu6(self.conv(x)) + + m = M().eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + + # program capture + m, guards = torchdynamo.export( # noqa: F841 + m, + *copy.deepcopy(example_inputs), + aten_graph=True, + ) + customized_equivalent_types = get_equivalent_types() + customized_equivalent_types.append({torch.nn.ReLU6, torch.nn.functional.relu6}) + update_equivalent_types_dict(customized_equivalent_types) + fused_partitions = find_sequential_partitions( + m, + [torch.nn.Conv2d, torch.nn.ReLU6], + ) + self.assertEqual(len(fused_partitions), 1) diff --git a/test/quantization/pt2e_flow/test_metadata_porting.py b/test/quantization/pt2e_flow/test_metadata_porting.py new file mode 100644 index 0000000000..c2655d114b --- /dev/null +++ b/test/quantization/pt2e_flow/test_metadata_porting.py @@ -0,0 +1,521 @@ +# Owner(s): ["oncall: quantization"] +import copy +import unittest + +import torch +import torch._export +from torch.fx import Node +from torch.testing._internal.common_quantization import QuantizationTestCase +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef + +from torchao.quantization.pt2e_flow.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e_flow.quantizer import QuantizationAnnotation, Quantizer +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, +) +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer_utils import ( + OP_TO_ANNOTATOR, +) + + +class TestHelperModules: + class Conv2dWithObsSharingOps(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.hardtanh = torch.nn.Hardtanh() + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + x = self.conv(x) + x = self.adaptive_avg_pool2d(x) + x = self.hardtanh(x) + x = x.view(-1, 3) + x = self.linear(x) + return x + + +def _tag_partitions( + backend_name: str, op_name: str, annotated_partitions: list[list[Node]] +): + for index, partition_nodes in enumerate(annotated_partitions): + tag_name = backend_name + "_" + op_name + "_" + str(index) + for node in partition_nodes: + assert "quantization_tag" not in node.meta, f"{node} is already tagged" + node.meta["quantization_tag"] = tag_name + + +_QUANT_OPS = { + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.choose_qparams.tensor, +} + + +# TODO: rename to TestPortMetadataPass to align with the util name? 
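For orientation, a minimal sketch of the end-to-end flow these tests exercise: capture with export_for_training, annotate and insert observers with prepare_pt2e, calibrate, then lower with convert_pt2e. It uses only calls that appear elsewhere in this patch; the toy Sequential model is illustrative, and it assumes the torchao.quantization.pt2e_flow package added by this patch is importable.

    # Illustrative sketch only -- not part of the patch.
    import torch
    from torch.export import export_for_training

    from torchao.quantization.pt2e_flow.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # Toy model and example inputs (hypothetical; any 4D-input model works).
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 3, 5, 5),)

    # 1. Capture the model as an ATen graph.
    m = export_for_training(model, example_inputs).module()

    # 2. Annotate the graph and insert observers via a quantizer.
    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=True)
    )
    m = prepare_pt2e(m, quantizer)

    # 3. Calibrate, then lower to quantize/dequantize ops.
    m(*example_inputs)
    m = convert_pt2e(m)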
+@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") +class TestMetaDataPorting(QuantizationTestCase): + def _test_quant_tag_preservation_through_decomp( + self, model, example_inputs, from_node_to_tags + ): + ep = torch.export.export(model, example_inputs, strict=True) + found_tags = True + not_found_nodes = "" + for from_node, tag in from_node_to_tags.items(): + for n in ep.graph_module.graph.nodes: + from_node_meta = n.meta.get("from_node", None) + if from_node_meta is None: + continue + if not isinstance(from_node_meta, list): + raise ValueError( + f"from_node metadata is of type {type(from_node_meta)}, but expected list" + ) + for meta in from_node_meta: + node_target = meta.target + if node_target == str(from_node): + node_tag = n.meta.get("quantization_tag", None) + if node_tag is None or tag != node_tag: + not_found_nodes += str(n.target) + ", " + found_tags = False + break + if not found_tags: + break + self.assertTrue( + found_tags, + f"Decomposition did not preserve quantization tag for {not_found_nodes}", + ) + + def _test_metadata_porting( + self, + model, + example_inputs, + quantizer, + node_tags=None, + ) -> torch.fx.GraphModule: + m_eager = model.eval() + + # program capture + m = copy.deepcopy(m_eager) + m = torch.export.export_for_training( + m, + example_inputs, + ).module() + + m = prepare_pt2e(m, quantizer) + # Calibrate + m(*example_inputs) + m = convert_pt2e(m) + + m(*example_inputs) + recorded_node_tags = {} + for n in m.graph.nodes: + if "quantization_tag" not in n.meta: + continue + if n.op == "call_function" and n.target in _QUANT_OPS: + key = n.target + elif n.op == "get_attr": + key = "get_attr" + else: + continue + + if key not in recorded_node_tags: + recorded_node_tags[key] = set() + + if ( + n.op == "call_function" + and n.meta["quantization_tag"] in recorded_node_tags[key] + ): + raise ValueError( + f"{key} {n.format_node()} has tag {n.meta['quantization_tag']} that " + "is associated with another node of the same type" + ) + recorded_node_tags[key].add(n.meta["quantization_tag"]) + + self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys())) + for k, v in recorded_node_tags.items(): + self.assertEqual(v, node_tags[k]) + return m + + @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack + # trace of the mode torch function impl doesn't match the traced graph stored lineno. 
+ def test_simple_metadata_porting(self): + """ + Model under test + conv2d -> avgpool -> hardtanh -> linear + Check quantization tags on conv2d, avgpool and linear are correctly set + """ + + class BackendAQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + backend_string = "BackendA" + quantization_config = get_symmetric_quantization_config( + is_per_channel=True + ) + annotated_partitions = OP_TO_ANNOTATOR["linear"]( + gm, quantization_config + ) + _tag_partitions(backend_string, "linear", annotated_partitions) + annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config) + _tag_partitions(backend_string, "conv2d", annotated_partitions) + annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"]( + gm, quantization_config + ) + _tag_partitions( + backend_string, "adaptive_avg_pool2d", annotated_partitions + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 5),) + get_attr_tags = { + "BackendA_conv2d_0", + "BackendA_linear_0", + } + quantize_per_tensor_tags = { + "BackendA_conv2d_0", + "BackendA_adaptive_avg_pool2d_0", + "BackendA_linear_0", + } + dequantize_per_tensor_tags = { + "BackendA_adaptive_avg_pool2d_0", + "BackendA_conv2d_0", + "BackendA_linear_0", + } + dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} + node_tags = { + "get_attr": get_attr_tags, + torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, + } + m = self._test_metadata_porting( + TestHelperModules.Conv2dWithObsSharingOps(), + example_inputs, + BackendAQuantizer(), + node_tags, + ) + + from_node_to_tags = { + torch.ops.aten.adaptive_avg_pool2d.default: "BackendA_adaptive_avg_pool2d_0", + torch.ops.aten.linear.default: "BackendA_linear_0", + } + self._test_quant_tag_preservation_through_decomp( + m, example_inputs, from_node_to_tags + ) + + def test_metadata_porting_with_no_quant_inbetween(self): + """ + Model under test + conv2d -> avgpool -> hardtanh -> linear + Dont quantize avgpool + Check quantization tags on conv2d and linear are correctly set + """ + + class BackendAQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + backend_string = "BackendA" + quantization_config = get_symmetric_quantization_config( + is_per_channel=True + ) + annotated_partitions = OP_TO_ANNOTATOR["linear"]( + gm, quantization_config + ) + _tag_partitions(backend_string, "linear", annotated_partitions) + annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config) + _tag_partitions(backend_string, "conv2d", annotated_partitions) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 5),) + get_attr_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} + quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} + dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} + dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} + node_tags = { + "get_attr": get_attr_tags, + torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 
dequantize_per_channel_tags, + } + self._test_metadata_porting( + TestHelperModules.Conv2dWithObsSharingOps(), + example_inputs, + BackendAQuantizer(), + node_tags, + ) + + @unittest.skip("Temporarily disabled") + def test_metadata_porting_for_dq(self): + """ + Model under test + conv2d -> avgpool -> hardtanh -> linear + Quantize all except linear. + Quantize linear with dynamic quantization + Check quantization tags on conv2d, avgpool and linear are correctly set + """ + + class BackendAQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + backend_string = "BackendA" + # static quantiazation + quantization_config = get_symmetric_quantization_config( + is_per_channel=True + ) + annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config) + _tag_partitions(backend_string, "conv2d", annotated_partitions) + annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"]( + gm, quantization_config + ) + _tag_partitions( + backend_string, "adaptive_avg_pool2d", annotated_partitions + ) + + # dynamic quantization + quantization_config_dynamic = get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + annotated_partitions = OP_TO_ANNOTATOR["linear"]( + gm, quantization_config_dynamic + ) + _tag_partitions(backend_string, "linear_dynamic", annotated_partitions) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 5),) + # TODO: add get_attr_tags when the test is re-enabled + get_attr_tags = {} + quantize_per_tensor_tags = { + "BackendA_conv2d_0", + "BackendA_adaptive_avg_pool2d_0", + } + quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} + choose_qparams_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} + dequantize_per_tensor_tags = { + "BackendA_adaptive_avg_pool2d_0", + "BackendA_conv2d_0", + } + dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} + dequantize_per_channel_tags = { + "BackendA_conv2d_0", + "BackendA_linear_dynamic_0", + } + node_tags = { + "get_attr": get_attr_tags, + torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, + torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tensor_tags, + } + self._test_metadata_porting( + TestHelperModules.Conv2dWithObsSharingOps(), + example_inputs, + BackendAQuantizer(), + node_tags, + ) + + def test_metadata_porting_for_two_dq(self): + """ + Model under test + conv2d -> avgpool -> hardtanh -> linear + Quantize linear and conv with dynamic quantization + Check quantization tags on conv2d, avgpool and linear are correctly set + """ + + class BackendAQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + backend_string = "BackendA" + + # dynamic quantization + quantization_config_dynamic = get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + annotated_partitions = OP_TO_ANNOTATOR["conv"]( + gm, quantization_config_dynamic + ) + _tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions) + annotated_partitions = OP_TO_ANNOTATOR["linear"]( + gm, quantization_config_dynamic + ) + 
_tag_partitions(backend_string, "linear_dynamic", annotated_partitions) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 5),) + get_attr_tags = { + "BackendA_conv2d_dynamic_0", + "BackendA_linear_dynamic_0", + } + choose_qparams_tensor_tags = { + "BackendA_conv2d_dynamic_0", + "BackendA_linear_dynamic_0", + } + quantize_per_tensor_tensor_tags = { + "BackendA_conv2d_dynamic_0", + "BackendA_linear_dynamic_0", + } + dequantize_per_tensor_tensor_tags = { + "BackendA_conv2d_dynamic_0", + "BackendA_linear_dynamic_0", + } + dequantize_per_channel_tags = { + "BackendA_conv2d_dynamic_0", + "BackendA_linear_dynamic_0", + } + node_tags = { + "get_attr": get_attr_tags, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, + torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags, + } + self._test_metadata_porting( + TestHelperModules.Conv2dWithObsSharingOps(), + example_inputs, + BackendAQuantizer(), + node_tags, + ) + + def test_metadata_porting_for_dq_no_static_q(self): + """ + Model under test + conv2d -> avgpool -> hardtanh -> linear + Dont quantize anything except linear. + Quantize linear with dynamic quantization + Check quantization tags on conv2d, avgpool and linear are correctly set + """ + + class BackendAQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + backend_string = "BackendA" + # dynamic quantization + quantization_config_dynamic = get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + annotated_partitions = OP_TO_ANNOTATOR["linear"]( + gm, quantization_config_dynamic + ) + _tag_partitions(backend_string, "linear_dynamic", annotated_partitions) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 5),) + get_attr_tags = {"BackendA_linear_dynamic_0"} + choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"} + quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} + dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} + dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"} + node_tags = { + "get_attr": get_attr_tags, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, + torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags, + } + self._test_metadata_porting( + TestHelperModules.Conv2dWithObsSharingOps(), + example_inputs, + BackendAQuantizer(), + node_tags, + ) + + def test_no_metadata_porting(self): + class BackendAQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + quantization_config = get_symmetric_quantization_config( + is_per_channel=True + ) + OP_TO_ANNOTATOR["linear"](gm, quantization_config) + OP_TO_ANNOTATOR["conv"](gm, quantization_config) + OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 5),) + node_tags = {} + m = self._test_metadata_porting( + 
TestHelperModules.Conv2dWithObsSharingOps(), + example_inputs, + BackendAQuantizer(), + node_tags, + ) + + from_node_to_tags = {} + self._test_quant_tag_preservation_through_decomp( + m, example_inputs, from_node_to_tags + ) + + def test_no_metadata_porting_through_unknown_ops(self): + """ + Model under test + matmul -> add -> relu + matmul has get_attr as first input, but the quantization_tag should not be + propagated to add even if it's part of a chain that ends at get_attr + """ + + class MatmulWithConstInput(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_parameter("w", torch.nn.Parameter(torch.rand(8, 16))) + + def forward(self, x, y): + x = torch.matmul(self.w, x) + z = x + y + return torch.nn.functional.relu(z) + + class BackendAQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + qconfig = get_symmetric_quantization_config() + for n in gm.graph.nodes: + if n.op != "call_function": + continue + + n.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={n.args[0]: qconfig.input_activation}, + output_qspec=qconfig.output_activation, + ) + + tag = str(n.target) + n.meta["quantization_tag"] = tag + for arg in n.args: + if arg.op == "get_attr": + arg.meta["quantization_tag"] = tag + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(16, 24), torch.randn(8, 24)) + get_attr_tags = {"aten.matmul.default"} + quantize_per_tensor_tensor_tags = { + "aten.matmul.default", + "aten.add.Tensor", + "aten.relu.default", + } + dequantize_per_tensor_tensor_tags = { + "aten.matmul.default", + "aten.add.Tensor", + "aten.relu.default", + } + node_tags = { + "get_attr": get_attr_tags, + torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tensor_tags, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tensor_tags, + } + self._test_metadata_porting( + MatmulWithConstInput(), + example_inputs, + BackendAQuantizer(), + node_tags, + ) diff --git a/test/quantization/pt2e_flow/test_numeric_debugger.py b/test/quantization/pt2e_flow/test_numeric_debugger.py new file mode 100644 index 0000000000..3c9a7783a7 --- /dev/null +++ b/test/quantization/pt2e_flow/test_numeric_debugger.py @@ -0,0 +1,350 @@ +# Owner(s): ["oncall: quantization"] + +import copy +import unittest +from collections import Counter + +import torch +from torch._dynamo.test_case import TestCase as TorchDynamoTestCase +from torch.export import export_for_training +from torch.testing._internal.common_quantization import TestHelperModules +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef + +from torchao.quantization.pt2e_flow import ( + CUSTOM_KEY, + NUMERIC_DEBUG_HANDLE_KEY, + compare_results, + extract_results_from_loggers, + generate_numeric_debug_handle, + prepare_for_propagation_comparison, +) +from torchao.quantization.pt2e_flow.pt2e.graph_utils import bfs_trace_with_node_process +from torchao.quantization.pt2e_flow.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, +) + + +@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") +class TestNumericDebugger(TorchDynamoTestCase): + def _assert_each_node_has_debug_handle(self, model) -> None: + def _assert_node_has_debug_handle(node): + self.assertTrue( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in 
node.meta[CUSTOM_KEY], + f"Node {node} doesn't have debug handle", + ) + + bfs_trace_with_node_process(model, _assert_node_has_debug_handle) + + def _extract_debug_handles(self, model) -> dict[str, int]: + debug_handle_map: dict[str, int] = {} + + def _extract_debug_handles_from_node(node): + nonlocal debug_handle_map + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][ + NUMERIC_DEBUG_HANDLE_KEY + ] + + bfs_trace_with_node_process(model, _extract_debug_handles_from_node) + + return debug_handle_map + + def _extract_debug_handles_with_prev_decomp_op(self, model) -> dict[str, int]: + prev_decomp_op_to_debug_handle_map: dict[str, int] = {} + + def _extract_debug_handles_with_prev_decomp_op_from_node(node): + nonlocal prev_decomp_op_to_debug_handle_map + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + prev_decomp_op = str(node.meta.get("nn_module_stack")) + debug_handle = node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] + if prev_decomp_op not in prev_decomp_op_to_debug_handle_map: + prev_decomp_op_to_debug_handle_map[prev_decomp_op] = debug_handle + else: + assert ( + prev_decomp_op_to_debug_handle_map[prev_decomp_op] + == debug_handle + ), f"Node {node} has different debug handle {debug_handle}" + "than previous node sharing the same decomp op {prev_decomp_op}" + + bfs_trace_with_node_process( + model, _extract_debug_handles_with_prev_decomp_op_from_node + ) + return prev_decomp_op_to_debug_handle_map + + def test_simple(self): + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + self._assert_each_node_has_debug_handle(ep) + debug_handle_map = self._extract_debug_handles(ep) + + self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map)) + + def test_control_flow(self): + m = TestHelperModules.ControlFlow() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + + self._assert_each_node_has_debug_handle(ep) + debug_handle_map = self._extract_debug_handles(ep) + + self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map)) + + def test_quantize_pt2e_preserve_handle(self): + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + m = ep.module() + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=False) + ) + m = prepare_pt2e(m, quantizer) + debug_handle_map = self._extract_debug_handles(m) + res_counter = Counter(debug_handle_map.values()) + repeated_debug_handle_ids = [1, 2, 3] + # 3 ids were repeated because we copy over the id from node to its output observer + # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default + for dh_id in repeated_debug_handle_ids: + self.assertEqual(res_counter[dh_id], 2) + + m(*example_inputs) + m = convert_pt2e(m) + self._assert_each_node_has_debug_handle(ep) + debug_handle_map = self._extract_debug_handles(m) + res_counter = Counter(debug_handle_map.values()) + # same set of ids where repeated, because we copy over the id from observer/fake_quant to + # dequantize node + repeated_debug_handle_ids = [1, 2, 3] + for dh_id in repeated_debug_handle_ids: + self.assertEqual(res_counter[dh_id], 2) + + def 
test_copy_preserve_handle(self): + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + ep = torch.export.export(m, example_inputs, strict=True) + generate_numeric_debug_handle(ep) + + self._assert_each_node_has_debug_handle(ep) + debug_handle_map_ref = self._extract_debug_handles(ep) + + ep_copy = copy.copy(ep) + debug_handle_map = self._extract_debug_handles(ep_copy) + + self._assert_each_node_has_debug_handle(ep) + self.assertEqual(debug_handle_map, debug_handle_map_ref) + + def test_deepcopy_preserve_handle(self): + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + ep = torch.export.export(m, example_inputs, strict=True) + generate_numeric_debug_handle(ep) + + debug_handle_map_ref = self._extract_debug_handles(ep) + ep_copy = copy.deepcopy(ep) + debug_handle_map = self._extract_debug_handles(ep_copy) + + self._assert_each_node_has_debug_handle(ep) + self.assertEqual(debug_handle_map, debug_handle_map_ref) + + @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack + # trace of the mode torch function impl doesn't match the traced graph stored lineno. + def test_re_export_preserve_handle(self): + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + m = ep.module() + + self._assert_each_node_has_debug_handle(ep) + debug_handle_map_ref = self._extract_debug_handles(ep) + + ep_reexport = export_for_training(m, example_inputs) + + self._assert_each_node_has_debug_handle(ep_reexport) + debug_handle_map = self._extract_debug_handles(ep_reexport) + + self.assertEqual(debug_handle_map, debug_handle_map_ref) + + def test_run_decompositions_same_handle_id(self): + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + + self._assert_each_node_has_debug_handle(ep) + debug_handle_map_ref = self._extract_debug_handles(ep) + + ep_copy = copy.copy(ep) + ep_copy = ep_copy.run_decompositions() + + self._assert_each_node_has_debug_handle(ep_copy) + debug_handle_map = self._extract_debug_handles(ep_copy) + + # checking the map still has the same ids, the node may change + self.assertEqual( + set(debug_handle_map.values()), set(debug_handle_map_ref.values()) + ) + + def test_run_decompositions_map_handle_to_new_nodes(self): + test_models = [ + TestHelperModules.TwoLinearModule(), + TestHelperModules.Conv2dThenConv1d(), + ] + + for m in test_models: + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + + self._assert_each_node_has_debug_handle(ep) + pre_decomp_to_debug_handle_map_ref = ( + self._extract_debug_handles_with_prev_decomp_op(ep) + ) + + ep_copy = copy.copy(ep) + ep_copy = ep_copy.run_decompositions() + self._assert_each_node_has_debug_handle(ep_copy) + pre_decomp_to_debug_handle_map = ( + self._extract_debug_handles_with_prev_decomp_op(ep_copy) + ) + + # checking the map still has the same ids, the node may change + self.assertEqual( + pre_decomp_to_debug_handle_map, pre_decomp_to_debug_handle_map_ref + ) + + def test_prepare_for_propagation_comparison(self): + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + m = ep.module() + m_logger = prepare_for_propagation_comparison(m) + ref = 
m(*example_inputs) + res = m_logger(*example_inputs) + + from torchao.quantization.pt2e_flow.pt2e._numeric_debugger import OutputLogger + + loggers = [m for m in m_logger.modules() if isinstance(m, OutputLogger)] + self.assertEqual(len(loggers), 3) + self.assertTrue("conv2d" in [logger.node_name for logger in loggers]) + self.assertEqual(res, ref) + + def test_extract_results_from_loggers(self): + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + m = ep.module() + m_ref_logger = prepare_for_propagation_comparison(m) + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=False) + ) + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + m_quant_logger = prepare_for_propagation_comparison(m) + + m_ref_logger(*example_inputs) + m_quant_logger(*example_inputs) + ref_results = extract_results_from_loggers(m_ref_logger) + quant_results = extract_results_from_loggers(m_quant_logger) + comparison_results = compare_results(ref_results, quant_results) + for node_summary in comparison_results.values(): + if len(node_summary.results) > 0: + self.assertGreaterEqual(node_summary.results[0].sqnr, 35) + + def test_extract_results_from_loggers_list_output(self): + m = TestHelperModules.Conv2dWithSplit() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + m = ep.module() + m_ref_logger = prepare_for_propagation_comparison(m) + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=False) + ) + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + m_quant_logger = prepare_for_propagation_comparison(m) + + m_ref_logger(*example_inputs) + m_quant_logger(*example_inputs) + ref_results = extract_results_from_loggers(m_ref_logger) + quant_results = extract_results_from_loggers(m_quant_logger) + comparison_results = compare_results(ref_results, quant_results) + for node_summary in comparison_results.values(): + if len(node_summary.results) > 0: + sqnr = node_summary.results[0].sqnr + if isinstance(sqnr, list): + for sqnr_i in sqnr: + self.assertGreaterEqual(sqnr_i, 35) + else: + self.assertGreaterEqual(sqnr, 35) + + def test_added_node_gets_unique_id(self) -> None: + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + ep = export_for_training(m, example_inputs) + generate_numeric_debug_handle(ep) + ref_handles = self._extract_debug_handles(ep) + ref_counter = Counter(ref_handles.values()) + for k, v in ref_counter.items(): + self.assertEqual( + v, + 1, + msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1", + ) + + # Now that we have unique ids, add a new node into the graph and re-generate + # to make sure that the new node gets a unique id. + last_node = next(iter(reversed(ep.graph.nodes))) + with ep.graph.inserting_before(last_node): + arg = last_node.args[0] + self.assertIsInstance(arg, (list, tuple)) + arg = arg[0] + # Add a function that only requires a single tensor input. + n = ep.graph.call_function(torch.ops.aten.relu.default, args=(arg,)) + arg.replace_all_uses_with(n, lambda x: x != n) + ep.graph_module.recompile() + + # Regenerate handles, make sure only the new relu node has a new id, and + # it doesn't clash with any of the existing ids. 
+ generate_numeric_debug_handle(ep) + + self._assert_each_node_has_debug_handle(ep) + handles_after_modification = self._extract_debug_handles(ep) + handles_counter = Counter(handles_after_modification.values()) + for name, handle in ref_handles.items(): + self.assertIn(name, handles_after_modification) + # Check that handle was unchanged. + self.assertEqual(handles_after_modification[name], handle) + # Check that total count was unchanged. + ref_count = ref_counter[handle] + after_count = handles_counter[handle] + self.assertEqual( + after_count, + ref_count, + msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}", + ) + + # Check for relu specifically. Avoid hardcoding the handle id since it + # may change with future node ordering changes. + self.assertNotEqual(handles_after_modification["relu_default"], 0) + self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1) diff --git a/test/quantization/pt2e_flow/test_quantize_pt2e.py b/test/quantization/pt2e_flow/test_quantize_pt2e.py new file mode 100644 index 0000000000..2940586c64 --- /dev/null +++ b/test/quantization/pt2e_flow/test_quantize_pt2e.py @@ -0,0 +1,2598 @@ +# Owner(s): ["oncall: quantization"] +# ruff: noqa: F841 + + +import torch +from torch import Tensor +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.qconfig import ( + QConfig, + default_per_channel_symmetric_qnnpack_qconfig, + per_channel_weight_observer_range_neg_127_to_127, + weight_observer_range_neg_127_to_127, +) +from torch.export import export_for_training +from torch.fx import Node +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, +) +from torch.testing._internal.common_quantization import ( + TestHelperModules, + skipIfNoQNNPACK, +) +from torch.testing._internal.common_utils import ( + TEST_CUDA, + TEST_HPU, + TemporaryFileName, + instantiate_parametrized_tests, + parametrize, + skipIfHpu, +) + +import torchao +from torchao.quantization.pt2e_flow import ObserverOrFakeQuantize, observer +from torchao.quantization.pt2e_flow.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torchao.quantization.pt2e_flow.quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + FixedQParamsQuantizationSpec, + QuantizationAnnotation, + QuantizationSpec, + Quantizer, + SharedQuantizationSpec, +) +from torchao.quantization.pt2e_flow.quantizer.composable_quantizer import ( # noqa: F811 + ComposableQuantizer, +) +from torchao.quantization.pt2e_flow.quantizer.embedding_quantizer import ( # noqa: F811 + EmbeddingQuantizer, +) +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, +) +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer_utils import ( + OP_TO_ANNOTATOR, + QuantizationConfig, +) +from torchao.testing.pt2e.utils import PT2EQuantizationTestCase + + +@skipIfNoQNNPACK +class TestQuantizePT2E(PT2EQuantizationTestCase): + def test_simple_quantizer(self): + # TODO: use OP_TO_ANNOTATOR + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.conv2d.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + bias = node.args[2] + assert isinstance(bias, Node) + act_qspec = QuantizationSpec( + 
dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_weight_observer, + ) + bias_qspec = QuantizationSpec( + dtype=torch.float32, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.PlaceholderObserver, + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + bias: bias_qspec, + }, + output_qspec=act_qspec, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # two for input of the first conv, one for output for the first conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.ConvWithBNRelu(relu=False, bn=False), + example_inputs, + BackendAQuantizer(), + node_occurrence, + node_list, + ) + + def test_wo_annotate_conv_output_quantizer(self): + # TODO: use OP_TO_ANNOTATOR + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_weight_observer, + ) + bias_qspec = QuantizationSpec( + dtype=torch.float32, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.PlaceholderObserver, + ) + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.conv2d.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + bias = node.args[2] + assert isinstance(bias, Node) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + bias: bias_qspec, + }, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + m = torch.nn.Conv2d(2, 2, 1) + x = torch.rand(1, 2, 14, 14) + example_inputs = (x,) + m = self._quantize(m, BackendAQuantizer(), example_inputs) + # Ensure the conv has no observer inserted at output + node_occurrence = { + # two for input of conv + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 1, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 2, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.conv2d.default), + ] + self.checkGraphModuleNodes( + m, expected_node_list=node_list, 
expected_node_occurrence=node_occurrence + ) + + def test_max_pool2d_quantizer(self): + # TODO: use OP_TO_ANNOTATOR + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_weight_observer, + ) + bias_qspec = QuantizationSpec( + dtype=torch.float32, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.PlaceholderObserver, + ) + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.conv2d.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + bias = node.args[2] + assert isinstance(bias, Node) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + bias: bias_qspec, + }, + _annotated=True, + ) + if ( + node.op == "call_function" + and node.target == torch.ops.aten.max_pool2d.default + ): + maxpool_node = node + input_act = maxpool_node.args[0] + assert isinstance(input_act, Node) + maxpool_node.meta["quantization_annotation"] = ( + QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + }, + output_qspec=SharedQuantizationSpec( + (input_act, maxpool_node) + ), + _annotated=True, + ) + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + m = TestHelperModules.ConvMaxPool2d() + x = torch.rand(1, 2, 14, 14) + example_inputs = (x,) + m = self._quantize(m, BackendAQuantizer(), example_inputs) + node_occurrence = { + # two for input of conv + # one for input of maxpool + # one for output of maxpool + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 3, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 4, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.conv2d.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.max_pool2d.default), + ] + self.checkGraphModuleNodes( + m, expected_node_list=node_list, expected_node_occurrence=node_occurrence + ) + + def test_derived_qspec(self): + # TODO: use OP_TO_ANNOTATOR + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.conv2d.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + bias = node.args[2] + assert isinstance(bias, Node) + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + 
qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_weight_observer, + ) + + def derive_qparams_fn( + obs_or_fqs: list[ObserverOrFakeQuantize], + ) -> tuple[Tensor, Tensor]: + assert ( + len(obs_or_fqs) == 2 + ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" + act_obs_or_fq = obs_or_fqs[0] + weight_obs_or_fq = obs_or_fqs[1] + act_scale, act_zp = act_obs_or_fq.calculate_qparams() + ( + weight_scale, + weight_zp, + ) = weight_obs_or_fq.calculate_qparams() + return torch.tensor([act_scale * weight_scale]).to( + torch.float32 + ), torch.tensor([0]).to(torch.int32) + + bias_qspec = DerivedQuantizationSpec( + derived_from=[(input_act, node), (weight, node)], + derive_qparams_fn=derive_qparams_fn, + dtype=torch.int32, + quant_min=-(2**31), + quant_max=2**31 - 1, + qscheme=torch.per_tensor_symmetric, + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + bias: bias_qspec, + }, + output_qspec=act_qspec, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + + m = self._quantize(m, BackendAQuantizer(), example_inputs) + node_occurrence = { + # input, weight, bias, output for the conv + # note: quantize op for weight and bias are const propagated + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 4, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.conv2d.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ] + self.checkGraphModuleNodes( + m, expected_node_list=node_list, expected_node_occurrence=node_occurrence + ) + + def test_derived_qspec_per_channel(self): + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.conv2d.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + bias = node.args[2] + assert isinstance(bias, Node) + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_channel_affine, + is_dynamic=False, + ch_axis=0, + observer_or_fake_quant_ctr=observer.default_per_channel_weight_observer, + ) + + def derive_qparams_fn( + obs_or_fqs: list[ObserverOrFakeQuantize], + ) -> tuple[Tensor, Tensor]: + assert ( + len(obs_or_fqs) == 1 + ), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}" + weight_obs_or_fq = obs_or_fqs[0] + ( + weight_scale, + weight_zp, + ) = weight_obs_or_fq.calculate_qparams() + return weight_scale, torch.zeros_like(weight_scale) + + bias_qspec = DerivedQuantizationSpec( + derived_from=[(weight, node)], + 
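+                                # derived_from lists the (source, node) edges whose observers are handed
+                                # to derive_qparams_fn; here only the weight observer is used, so the bias
+                                # gets a per-channel scale derived from the weight and all-zero zero_points.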
derive_qparams_fn=derive_qparams_fn, + dtype=torch.int32, + quant_min=-(2**31), + quant_max=2**31 - 1, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + bias: bias_qspec, + }, + output_qspec=act_qspec, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + + m = self._quantize(m, BackendAQuantizer(), example_inputs) + + node_occurrence = { + # input, output for the conv + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 2, + # weight and bias for conv + # note: quantize op for weight and bias are const propagated + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 0, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 2, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ), + ns.call_function(torch.ops.aten.conv2d.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ] + self.checkGraphModuleNodes( + m, expected_node_list=node_list, expected_node_occurrence=node_occurrence + ) + + def test_fixed_qparams_qspec_ptq(self): + self._test_fixed_qparams_qspec(is_qat=False) + + # TODO: refactor and move this to test_quantize_pt2_qat.py + def test_fixed_qparams_qspec_qat(self): + self._test_fixed_qparams_qspec(is_qat=True) + + def _test_fixed_qparams_qspec(self, is_qat: bool): + class M(torch.nn.Module): + def forward(self, x): + return torch.sigmoid(x) + + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.sigmoid.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + act_qspec = FixedQParamsQuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + scale=1.0 / 256.0, + zero_point=0, + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + }, + output_qspec=act_qspec, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + m = M().eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + + m = self._quantize(m, BackendAQuantizer(), example_inputs, is_qat) + fixed_scale = 1.0 / 256.0 + fixed_zero_point = 0 + for n in m.graph.nodes: + if n.op == "call_function": + if ( + n.target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ): + scale_0 = n.args[1] + zero_point_0 = n.args[2] + if ( + n.target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): + scale_1 = n.args[1] + zero_point_1 = n.args[2] + self.assertEqual(scale_0, fixed_scale) + self.assertEqual(zero_point_0, fixed_zero_point) + self.assertEqual(scale_1, fixed_scale) + self.assertEqual(zero_point_1, fixed_zero_point) + node_occurrence = { + # two for input of the first conv, one for output for the first conv + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + 
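+            # for this sigmoid-only model: one quantize/dequantize pair at the input and one at the output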
ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 2, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.sigmoid.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ] + self.checkGraphModuleNodes( + m, expected_node_list=node_list, expected_node_occurrence=node_occurrence + ) + + def test_fixed_qparams_qspec_observer_dedup(self): + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + act_qspec = FixedQParamsQuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + scale=1.0 / 256.0, + zero_point=0, + ) + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.sigmoid.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + }, + output_qspec=act_qspec, + _annotated=True, + ) + elif ( + node.op == "call_function" + and node.target == torch.ops.aten.add.Tensor + ): + input_act0 = node.args[0] + assert isinstance(input_act, Node) + input_act1 = node.args[1] + assert isinstance(input_act, Node) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act0: act_qspec, + input_act1: act_qspec, + }, + output_qspec=act_qspec, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def forward(self, x, y): + return torch.sigmoid(x) + y + + def example_inputs(self): + return ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + ) + + m = M().eval() + example_inputs = m.example_inputs() + m = self._quantize(m, BackendAQuantizer(), example_inputs, is_qat=False) + + node_occurrence = { + # two for input of the first conv, one for output for the first conv + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 4, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 4, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.sigmoid.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.add.Tensor), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ] + self.checkGraphModuleNodes( + m, expected_node_list=node_list, expected_node_occurrence=node_occurrence + ) + + def test_shared_qspec(self): + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.conv2d.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + bias = node.args[2] + assert isinstance(bias, Node) + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + 
qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_weight_observer, + ) + bias_qspec = QuantizationSpec( + dtype=torch.float32, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.PlaceholderObserver, + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + bias: bias_qspec, + }, + output_qspec=act_qspec, + _annotated=True, + ) + elif node.target is torch.ops.aten.cat.default: + cat_node = node + input_nodes = cat_node.args[0] + first_input_node = input_nodes[0] + input_qspec_map = {} + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + input_qspec_map[first_input_node] = act_qspec + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, cat_node) + ) + for input_node in input_nodes[1:]: + input_qspec_map[input_node] = ( + share_qparams_with_input_act0_qspec + ) + + cat_node.meta["quantization_annotation"] = ( + QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act0_qspec, + _annotated=True, + ) + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + m = TestHelperModules.Conv2dWithCat().eval() + example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5)) + + # program capture + m = export_for_training( + m, + example_inputs, + ).module() + m = prepare_pt2e(m, BackendAQuantizer()) + # make sure the two observers for input are shared + conv_output_obs = [] + for n in m.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default: + conv_output_obs.append(getattr(m, next(iter(n.users)).target)) + if n.op == "call_function" and n.target == torch.ops.aten.cat.default: + inputs = n.args[0] + input0 = inputs[0] + input1 = inputs[1] + assert input0.op == "call_module" + assert input1.op == "call_module" + obs_ins0 = getattr(m, input0.target) + obs_ins1 = getattr(m, input1.target) + assert obs_ins0 == obs_ins1 + assert ( + len(conv_output_obs) == 2 + ), "expecting two observer that follows conv2d ops" + # checking that the output observers for the two convs are shared as well + assert conv_output_obs[0] == conv_output_obs[1] + + m(*example_inputs) + m = convert_pt2e(m) + + node_occurrence = { + # two for input of the first conv, one for output for the first conv + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 5, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 7, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.cat.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ] + self.checkGraphModuleNodes( + m, expected_node_list=node_list, expected_node_occurrence=node_occurrence + ) + + def _test_transitive_sharing_with_cat_helper(self, quantizer): + m = TestHelperModules.Conv2dWithTwoCat().eval() + example_inputs = ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + torch.randn(1, 6, 3, 3), + torch.randn(1, 6, 3, 3), + ) + + # program capture + m = export_for_training( + m, + example_inputs, + ).module() + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + # make sure the two input 
observers and output are shared + conv_output_obs = [] + for n in m.graph.nodes: + if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default: + conv_output_obs.append(getattr(m, next(iter(n.users)).target)) + if n.op == "call_function" and n.target == torch.ops.aten.cat.default: + inputs = n.args[0] + input0 = inputs[0] + input1 = inputs[1] + assert input0.op == "call_module" + assert input1.op == "call_module" + obs_ins0 = getattr(m, input0.target) + obs_ins1 = getattr(m, input1.target) + assert obs_ins0 == obs_ins1 + + output_obs = next(iter(n.users)) + assert output_obs.op == "call_module" + obs_ins2 = getattr(m, output_obs.target) + assert obs_ins0 == obs_ins2, "input observer does not match output" + + assert ( + len(conv_output_obs) == 2 + ), "expecting two observer that follows conv2d ops" + # checking that the output observers for the two convs are shared as well + assert conv_output_obs[0] == conv_output_obs[1] + + m(*example_inputs) + m = convert_pt2e(m) + + node_occurrence = { + # two for input of the first conv, one for output for the first conv + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 7, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 9, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.cat.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.cat.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ] + self.checkGraphModuleNodes( + m, expected_node_list=node_list, expected_node_occurrence=node_occurrence + ) + + def test_shared_qspec_transitivity(self): + """This tests the transitivity of SharedQuantizationSpec, that is + if A is shared with B, B is shared with C, then C should be shared with A as well + + x1 -> conv1 -> cat1 -----> cat2 + x2 -> conv2 -/ / + x3 -> add / + x4 / + + both cat has shared input and output, and because of cat and (cat1 -> cat2) is the same Tensor + so there is an implicit sharing here, all tensors connect to cat1 and cat2 are in the same + sharing group after transitive sharing + """ + + # TODO: refactor this to a common util + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.conv2d.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + bias = node.args[2] + assert isinstance(bias, Node) + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_weight_observer, + ) + bias_qspec = QuantizationSpec( + dtype=torch.float32, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.PlaceholderObserver, + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: 
act_qspec, + weight: weight_qspec, + bias: bias_qspec, + }, + output_qspec=act_qspec, + _annotated=True, + ) + elif node.target is torch.ops.aten.cat.default: + cat_node = node + input_nodes = cat_node.args[0] + first_input_node = input_nodes[0] + input_qspec_map = {} + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + input_qspec_map[first_input_node] = act_qspec + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, cat_node) + ) + for input_node in input_nodes[1:]: + input_qspec_map[input_node] = ( + share_qparams_with_input_act0_qspec + ) + + cat_node.meta["quantization_annotation"] = ( + QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act0_qspec, + _annotated=True, + ) + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + self._test_transitive_sharing_with_cat_helper(BackendAQuantizer()) + + def test_shared_qspec_transitivity_case_2(self): + """This tests the transitivity of SharedQuantizationSpec, that is + if A is shared with B, B is shared with C, then C should be shared with A as well + + x1 -> conv1 -> cat1 -----> cat2 + x2 -> conv2 -/ / + x3 -> add / + x4 / + + both cat has shared input and output, and because of cat and (cat1 -> cat2) is the same Tensor + so there is an implicit sharing here, all tensors connect to cat1 and cat2 are in the same + sharing group after transitive sharing + + the difference is that for this one, all edges and nodes are shared with the second input edge of cat + instead of the first input edge of cat as in previous example + """ + + # TODO: refactor this to a common util + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.conv2d.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + bias = node.args[2] + assert isinstance(bias, Node) + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_weight_observer, + ) + bias_qspec = QuantizationSpec( + dtype=torch.float32, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.PlaceholderObserver, + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + bias: bias_qspec, + }, + output_qspec=act_qspec, + _annotated=True, + ) + elif node.target is torch.ops.aten.cat.default: + cat_node = node + input_nodes = cat_node.args[0] + first_input_node = input_nodes[0] + second_input_node = input_nodes[1] + input_qspec_map = {} + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + input_qspec_map[second_input_node] = act_qspec + share_qparams_with_input_act1_qspec = SharedQuantizationSpec( + (second_input_node, cat_node) + ) + input_qspec_map[first_input_node] = ( + 
share_qparams_with_input_act1_qspec + ) + + cat_node.meta["quantization_annotation"] = ( + QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act1_qspec, + _annotated=True, + ) + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + self._test_transitive_sharing_with_cat_helper(BackendAQuantizer()) + + def test_allow_implicit_sharing(self): + """This tests the allow_transitive_sharing flag of QuantizationAnnotation, that is + if a node is configured with allow_implicit_sharing=False, we will not have implicit sharing + for node and (node, consumer) even they refer to the same Tensor + + x1 -> add1 -----> add3 + x2 -/ / + x3 -> add2 / + x4 -/ + + all add has shared input and output, and second input is using shared quantization spec pointing + to first input, but we set allow_implicit_sharing to False for all add nodes so input and output of add1, + add2 and add3 will each belong to one sharing group, so we'll have: + + x1 -> obs1 -> add1 -> obs1 -> obs3--> add3 -> obs3 + x2 -> obs1 -/ / + x3 -> obs2 -> add2 -> obs2 -> obs3 + x4 -> obs2 -/ + """ + + # TODO: refactor this to a common util + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if node.target is torch.ops.aten.add.Tensor: + add_node = node + first_input_node = add_node.args[0] + second_input_node = add_node.args[1] + input_qspec_map = {} + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + input_qspec_map[second_input_node] = act_qspec + share_qparams_with_input_act1_qspec = SharedQuantizationSpec( + (second_input_node, add_node) + ) + input_qspec_map[first_input_node] = ( + share_qparams_with_input_act1_qspec + ) + + add_node.meta["quantization_annotation"] = ( + QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act1_qspec, + allow_implicit_sharing=False, + _annotated=True, + ) + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + m = TestHelperModules.ThreeAdd().eval() + example_inputs = ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + ) + + # program capture + m = export_for_training( + m, + example_inputs, + ).module() + quantizer = BackendAQuantizer() + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + observers = [] + for n in m.graph.nodes: + if n.target == torch.ops.aten.add.Tensor: + input_obs1 = getattr(m, n.args[0].target) + input_obs2 = getattr(m, n.args[1].target) + output_obs = getattr(m, next(iter(n.users)).target) + self.assertIs(input_obs1, input_obs2) + self.assertIs(input_obs1, output_obs) + observers.append(input_obs1) + assert len(observers) == 3 + self.assertIsNot(observers[0], observers[1]) + self.assertIsNot(observers[0], observers[2]) + self.assertIsNot(observers[1], observers[2]) + + @skipIfHpu + @parametrize("dtype", (torch.float32, torch.bfloat16)) + @parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn)) + def test_quantization_dtype(self, dtype, quant_dtype): + class DtypeActQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + info_fun = torch.iinfo if quant_dtype == torch.int16 else torch.finfo + activate_qspec = QuantizationSpec( + dtype=quant_dtype, + 
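+                    # quant_min/quant_max follow from the dtype's numeric limits
+                    # (torch.iinfo for int16, torch.finfo for the float8 variants)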
quant_min=int(info_fun(quant_dtype).min), + quant_max=int(info_fun(quant_dtype).max), + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + int8_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_weight_observer, + ) + quantization_config = QuantizationConfig( + input_activation=activate_qspec, + weight=int8_qspec, + bias=None, + output_activation=activate_qspec, + ) + OP_TO_ANNOTATOR["conv"](model, quantization_config) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def __init__(self, dtype): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3, dtype=dtype) + + def forward(self, x): + return self.conv(x) + + quantizer = DtypeActQuantizer() + node_occurrence = { + # one for input of the first conv, one for output for the first conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + example_inputs = (torch.randn(1, 3, 3, 3, dtype=dtype),) + m = self._test_quantizer( + M(dtype).eval(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def verify_quant_dequant_iotypes(m): + for node in m.graph.nodes: + if ( + node.op == "call_function" + and node.target.__name__ == "dequantize_per_tensor.default" + ): + # Check dequantize node + dequant_node = node + dequant_in_dtype = dequant_node.args[5] + dequant_out_dtype = torch.float32 + if "out_dtype" in dequant_node.kwargs: + dequant_out_dtype = dequant_node.kwargs["out_dtype"] + + # Check preceding quantize node + # Depending on fold_quantize flag, quantize node may be absent + quant_node = node.args[0] + if ( + quant_node.op == "call_function" + and quant_node.target.__name__ == "quantize_per_tensor.default" + ): + quant_in_dtype = torch.float32 + if "val" in quant_node.args[0].meta: + quant_in_dtype = quant_node.args[0].meta["val"].dtype + quant_out_dtype = quant_node.args[5] + assert ( + quant_in_dtype == dequant_out_dtype + and quant_out_dtype == dequant_in_dtype + ), "quant dequant io dtype check failed!" 
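+        # (the helper above checks that each surviving quantize/dequantize pair agrees on its input/output dtypes)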
+ + verify_quant_dequant_iotypes(m) + + def test_input_edge_sanity_check(self): + class M(torch.nn.Module): + def forward(self, x): + return x + 6 + + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.add.Tensor + ): + input_act1 = node.args[0] + # this is a constant, so not valid for annotation + input_act2 = node.args[1] + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act1: act_qspec, + # this is supposed to error out + input_act2: act_qspec, + }, + output_qspec=act_qspec, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + m = M().eval() + example_inputs = torch.randn(1, 2, 3, 3) + m = export_for_training(m, (example_inputs,)).module() + with self.assertRaises(Exception): + m = prepare_pt2e(m, BackendAQuantizer()) + + def test_fold_quantize(self): + """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)""" + m = self._get_pt2e_quantized_linear() + node_occurrence = { + # quantize op for weight node is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 3, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_fold_quantize_per_channel(self): + """Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded)""" + m = self._get_pt2e_quantized_linear(is_per_channel=True) + node_occurrence = { + # quantize op for weight node is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 2, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_dont_fold_other_constant(self): + """Make sure the constant propagation does not apply to things unrelated to + quantization + """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + self.dont_fold_me = torch.nn.Parameter(torch.randn(2, 2)) + + def forward(self, x): + t = self.dont_fold_me.t() + return self.linear(x) + t + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + # only quantize linear, so add is not quantized and the constant Tensor + # should not be folded + quantizer.set_module_type(torch.nn.Linear, operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = self._quantize(m, quantizer, example_inputs) + node_occurrence = { + # quantize op for weight node is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 3, + # transpose op not folded + ns.call_function(torch.ops.aten.t.default): 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_fold_all_ops_before_quantize(self): + """Test folding 
all ops that's before quantized operator: + Before: + get_attr(weight) -> transpose -> quantize -> dequantize + After: + get_attr(folded_weight) -> dequantize + """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.randn(2, 2) + + def forward(self, x): + t = self.weight.t() + return torch.nn.functional.linear(x, t) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = self._quantize(m, quantizer, example_inputs) + node_occurrence = { + # quantize op for weight node is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 3, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_constant_prop_preserve_metadata(self): + """Test to make sure the get_attr node for const propagated weight Tensor gets the correct + metadata (from original get_attr node from weight) + """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config() + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = export_for_training( + m, + example_inputs, + ).module() + weight_meta = None + for n in m.graph.nodes: + if ( + n.op == "get_attr" + and next(iter(n.users)).target == torch.ops.aten.linear.default + ): + weight_meta = n.meta + break + assert weight_meta is not None, "Expect to find metadata for weight node" + + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + + for n in m.graph.nodes: + if n.op == "get_attr" and "frozen_param" in n.target: + for key in n.meta: + self.assertEqual(n.meta[key], weight_meta[key]) + + def test_save_load(self): + """Test save/load a quantized model""" + m = self._get_pt2e_quantized_linear() + example_inputs = (torch.randn(2, 2),) + ref_res = m(*example_inputs) + + with TemporaryFileName() as fname: + # serialization + quantized_ep = torch.export.export(m, example_inputs, strict=True) + torch.export.save(quantized_ep, fname) + # deserialization + loaded_ep = torch.export.load(fname) + loaded_quantized_model = loaded_ep.module() + res = loaded_quantized_model(*example_inputs) + self.assertEqual(ref_res, res) + + def test_composable_quantizer_throw(self): + class BadQuantizer(Quantizer): + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + for n in gm.graph.nodes: + n.meta["quantization_annotation"] = None + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + bad_quantizer = BadQuantizer() + composable_quantizer = ComposableQuantizer([quantizer, bad_quantizer]) + m_eager = TestHelperModules.ConvLinearWPermute().eval() + example_inputs = (torch.randn(2, 3, 4, 4),) + self.assertRaises( + RuntimeError, + lambda: self._test_quantizer( + m_eager, example_inputs, composable_quantizer, {} + ), + ) + + def test_transform_for_annotation(self): + class TestQuantizer(Quantizer): + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> 
torch.fx.GraphModule: + # Make a copy of the graph to ensure that we are using the + # return value of this function. + graph = torch.fx.Graph() + graph.graph_copy(model.graph, {}) + for n in graph.nodes: + if n.target == torch.ops.aten.add.Tensor: + n.target = torch.ops.aten.mul.Tensor + model = torch.fx.GraphModule(model, graph) + return model + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def forward(self, x): + return x + 3 + + m = M().eval() + quantizer = TestQuantizer() + example_inputs = (torch.randn(1, 2, 3, 3),) + m = export_for_training(m, example_inputs).module() + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + node_occurrence = { + ns.call_function(torch.ops.aten.add.Tensor): 0, + ns.call_function(torch.ops.aten.mul.Tensor): 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_composable_quantizer_transform_for_annotation(self): + class TestQuantizer1(Quantizer): + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + for n in model.graph.nodes: + if n.target == torch.ops.aten.add.Tensor: + n.target = torch.ops.aten.mul.Tensor + return model + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class TestQuantizer2(Quantizer): + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + for n in model.graph.nodes: + if n.target == torch.ops.aten.sub.Tensor: + n.target = torch.ops.aten.div.Tensor + return model + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def forward(self, x, y, z): + return x + y - z + + m = M().eval() + quantizer = ComposableQuantizer([TestQuantizer1(), TestQuantizer2()]) + example_inputs = ( + torch.randn(1, 2, 3, 3), + torch.randn(1, 2, 3, 3), + torch.randn(1, 2, 3, 3), + ) + m = export_for_training(m, example_inputs).module() + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + node_occurrence = { + ns.call_function(torch.ops.aten.add.Tensor): 0, + ns.call_function(torch.ops.aten.sub.Tensor): 0, + ns.call_function(torch.ops.aten.mul.Tensor): 1, + ns.call_function(torch.ops.aten.div.Tensor): 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_embedding_quantizer(self): + m_eager = TestHelperModules.EmbeddingModule().eval() + indices = torch.tensor( + [ + 9, + 6, + 5, + 7, + 8, + 8, + 9, + 2, + 8, + 6, + 6, + 9, + 1, + 6, + 8, + 8, + 3, + 2, + 3, + 6, + 3, + 6, + 5, + 7, + 0, + 8, + 4, + 6, + 5, + 8, + 2, + 3, + ] + ) + example_inputs = (indices,) + + quantizer = EmbeddingQuantizer() + node_occurrence = { + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.aten.embedding.default, + ] + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + dynamic_qconfig = QConfig( + activation=act_affine_quant_obs, + 
weight=torch.ao.quantization.per_channel_weight_observer_range_neg_127_to_127, + ) + qconfig = torch.ao.quantization.default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig) + qconfig_mapping = qconfig_mapping.set_object_type( + torch.nn.Embedding, torch.ao.quantization.float_qparams_weight_only_qconfig + ) + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + node_list, + True, + qconfig_mapping, + ) + + def test_composable_quantizer_linear_conv(self): + dynamic_quantizer = XNNPACKQuantizer() + quantization_config_dynamic = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=True + ) + dynamic_quantizer.set_global(quantization_config_dynamic) + static_quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + static_quantizer.set_global(quantization_config) + # Note that dynamic quantization must be applied first here. + # this is because static quantizer also quantizes linear with static qspec + # and if we apply static_quantizer first then dynamic_quantizer cannot be applied + composable_quantizer = ComposableQuantizer( + [dynamic_quantizer, static_quantizer] + ) + m_eager = TestHelperModules.ConvLinearWPermute().eval() + + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + act_affine_quant_obs = ( + torch.ao.quantization.observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + ) + dynamic_qconfig = QConfig( + activation=act_affine_quant_obs, + weight=weight_observer_range_neg_127_to_127, + ) + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig) + # Had to turn off check against fx because fx quant workflow does not seem + # to propagate observers for permute node for this model. 
+ # Suprisingly it does propagate it for EmbeddingConvLinearModule + # TODO: Figure out the right behavior for propagation + self._test_quantizer( + m_eager, + example_inputs, + composable_quantizer, + node_occurrence, + [], + False, + qconfig_mapping, + ) + + def test_embedding_conv_linear_quantization(self): + m_eager = TestHelperModules.EmbeddingConvLinearModule().eval() + indices = torch.tensor( + [ + 9, + 6, + 5, + 7, + 8, + 8, + 9, + 2, + 8, + 6, + 6, + 9, + 1, + 6, + 8, + 8, + 3, + 2, + 3, + 6, + 3, + 6, + 5, + 7, + 0, + 8, + 4, + 6, + 5, + 8, + 2, + 3, + ] + ) + indices = torch.unsqueeze(indices, 0) + example_inputs = (indices,) + + embedding_quantizer = EmbeddingQuantizer() + dynamic_quantizer = XNNPACKQuantizer() + quantization_config_dynamic = get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + dynamic_quantizer.set_global(quantization_config_dynamic) + static_quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + static_quantizer.set_global(quantization_config) + composed_quantizer = ComposableQuantizer( + [embedding_quantizer, dynamic_quantizer, static_quantizer] + ) + + act_affine_quant_obs = ( + torch.ao.quantization.observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + ) + dynamic_qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127, + ) + qconfig = torch.ao.quantization.default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig) + qconfig_mapping = qconfig_mapping.set_object_type( + torch.nn.Embedding, torch.ao.quantization.float_qparams_weight_only_qconfig + ) + + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + self._test_quantizer( + m_eager, + example_inputs, + composed_quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def _get_node(self, m: torch.fx.GraphModule, target: torch._ops.OpOverload): + """ + Return the first node matching the specified target, throwing an exception + if no such batch norm node is found. + """ + for n in m.graph.nodes: + if n.target == target: + return n + raise ValueError("Did not find node with target ", target) + + def _test_move_exported_model_dropout(self, inplace: bool): + """ + Test switching dropout behavior between train and eval modes using + `move_exported_model_to_eval` and `move_exported_model_to_train` APIs. 
+ """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.dropout = torch.nn.Dropout(0.5, inplace=inplace) + + def forward(self, x): + return self.dropout(x) + + example_inputs = (torch.randn(1),) + m = M().train() + m = export_for_training(m, example_inputs).module() + if inplace: + target = torch.ops.aten.dropout_.default + else: + target = torch.ops.aten.dropout.default + + # Assert that dropout op exists and is in train mode + dropout_node = self._get_node(m, target) + self.assertTrue(dropout_node is not None) + self.assertTrue(dropout_node.args[2]) + + # Move to eval + torchao.quantization.pt2e_flow.move_exported_model_to_eval(m) + + # Assert that dropout op is now in eval mode + dropout_node = self._get_node(m, target) + self.assertTrue(dropout_node is not None) + self.assertTrue(not dropout_node.args[2]) + + # Move back to train + torchao.quantization.pt2e_flow.move_exported_model_to_train(m) + + # Assert that dropout op is now in train mode again + dropout_node = self._get_node(m, target) + self.assertTrue(dropout_node is not None) + self.assertTrue(dropout_node.args[2]) + + def test_move_exported_model_dropout(self): + self._test_move_exported_model_dropout(inplace=False) + + def test_move_exported_model_dropout_inplace(self): + self._test_move_exported_model_dropout(inplace=True) + + def _get_bn_train_eval_ops(self): + return ( + torch.ops.aten.batch_norm.default, + torch.ops.aten.batch_norm.default, + ) + + @parametrize( + "device", + ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["hpu"] if TEST_HPU else []), + ) + def test_move_exported_model_bn(self, device): + """ + Test switching batch_norm behavior between train and eval modes using + `move_exported_model_to_eval` and `move_exported_model_to_train` APIs. 
+ """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + return self.bn(x) + + if TEST_CUDA or TEST_HPU: + m = M().train().to(device) + example_inputs = (torch.randn((1, 3, 3, 3), device=device),) + + else: + m = M().train() + example_inputs = (torch.randn(1, 3, 3, 3),) + bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() + m = export_for_training(m, example_inputs).module() + + # Assert that batch norm op exists and is in train mode + bn_node = self._get_node(m, bn_train_op) + self.assertTrue(bn_node is not None) + self.assertTrue(bn_node.args[5]) + + # Move to eval + torchao.quantization.pt2e_flow.move_exported_model_to_eval(m) + + # Assert that batch norm op is now in eval mode + bn_node = self._get_node(m, bn_eval_op) + self.assertTrue(bn_node is not None) + + # Move to train + torchao.quantization.pt2e_flow.move_exported_model_to_train(m) + + # Assert that batch norm op is now in train mode again + bn_node = self._get_node(m, bn_train_op) + self.assertTrue(bn_node is not None) + self.assertTrue(bn_node.args[5]) + + def test_disallow_eval_train(self): + m = TestHelperModules.ConvWithBNRelu(relu=True) + example_inputs = (torch.rand(3, 3, 5, 5),) + + # Before export: this is OK + m.eval() + m.train() + + # After export: this is not OK + m = export_for_training(m, example_inputs).module() + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + # After prepare: still not OK + quantizer = XNNPACKQuantizer() + m = prepare_qat_pt2e(m, quantizer) + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + # After convert: still not OK + m = convert_pt2e(m) + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + @skipIfHpu + def test_allow_exported_model_train_eval(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, x): + x = self.bn(x) + x = self.dropout(x) + return x + + if TEST_CUDA: + m = M().train().cuda() + example_inputs = (torch.randn(1, 3, 3, 3).cuda(),) + else: + m = M().train() + example_inputs = (torch.randn(1, 3, 3, 3),) + bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() + m = export_for_training(m, example_inputs).module() + + def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool): + targets = [n.target for n in m.graph.nodes] + bn_op = bn_train_op if train else bn_eval_op + bn_node = self._get_node(m, bn_op) + self.assertTrue(bn_node is not None) + if TEST_CUDA: + self.assertEqual(bn_node.args[5], train) + dropout_node = self._get_node(m, torch.ops.aten.dropout.default) + self.assertEqual(dropout_node.args[2], train) + + # Before wrapping: this is not OK + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + # After wrapping: does not error and swaps the ops accordingly + torchao.quantization.pt2e_flow.allow_exported_model_train_eval(m) + m.eval() + _assert_ops_are_correct(m, train=False) + m.train() + _assert_ops_are_correct(m, train=True) + + # After prepare but before wrapping: this is not OK + quantizer = XNNPACKQuantizer() + m = prepare_qat_pt2e(m, quantizer) + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + # 
After prepare and after wrapping: does not error and swaps the ops accordingly + torchao.quantization.pt2e_flow.allow_exported_model_train_eval(m) + m.eval() + _assert_ops_are_correct(m, train=False) + m.train() + _assert_ops_are_correct(m, train=True) + + # After convert but before wrapping: this is not OK + m = convert_pt2e(m, fold_quantize=True) + with self.assertRaises(NotImplementedError): + m.eval() + with self.assertRaises(NotImplementedError): + m.train() + + # After convert and after wrapping: does not error and swaps the ops accordingly + torchao.quantization.pt2e_flow.allow_exported_model_train_eval(m) + m.eval() + _assert_ops_are_correct(m, train=False) + m.train() + _assert_ops_are_correct(m, train=True) + + def test_allow_exported_model_train_eval_idempotent(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.bn(x) + return x + + m = M().train() + example_inputs = (torch.randn(1, 3, 3, 3),) + m = export_for_training(m, example_inputs).module() + torchao.quantization.pt2e_flow.allow_exported_model_train_eval(m) + + # Mock m.recompile() to count how many times it's been called + m._recompile_count = 0 + + def _fake_recompile(): + m._recompile_count += 1 + + m.recompile = _fake_recompile + + # First train after export should always recompile + m.train() + self.assertNotEqual(m._recompile_count, 0) + count1 = m._recompile_count + + # Train -> train should not recompile + m.train() + self.assertEqual(m._recompile_count, count1) + + # Train -> eval should recompile + m.eval() + self.assertNotEqual(m._recompile_count, count1) + count2 = m._recompile_count + + # Eval -> eval should not recompile + m.eval() + self.assertEqual(m._recompile_count, count2) + + def test_model_is_exported(self): + m = TestHelperModules.ConvWithBNRelu(relu=True) + example_inputs = (torch.rand(3, 3, 5, 5),) + exported_gm = export_for_training(m, example_inputs).module() + fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs) + self.assertTrue( + torchao.quantization.pt2e_flow.pt2e.export_utils.model_is_exported( + exported_gm + ) + ) + self.assertFalse( + torchao.quantization.pt2e_flow.pt2e.export_utils.model_is_exported( + fx_traced_gm + ) + ) + self.assertFalse( + torchao.quantization.pt2e_flow.pt2e.export_utils.model_is_exported(m) + ) + + def test_reentrant(self): + """Test we can safely call quantization apis multiple times""" + m = TestHelperModules.ConvBnReLU2dAndLinearReLU() + example_inputs = (torch.randn(3, 3, 10, 10),) + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=True, is_qat=True) + ) + m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module() + m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer) + m(*example_inputs) + m.conv_bn_relu = convert_pt2e(m.conv_bn_relu) + + quantizer = XNNPACKQuantizer().set_module_type( + torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False) + ) + m = export_for_training(m, example_inputs).module() + m = prepare_pt2e(m, quantizer) + m = convert_pt2e(m) + + node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 4, + # one for weight + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 5, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + } + node_list = [ + ns.call_function( + 
torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.conv2d.default), + ns.call_function(torch.ops.aten.relu.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(torch.ops.aten.linear.default), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ] + self.checkGraphModuleNodes( + m, expected_node_occurrence=node_occurrence, expected_node_list=node_list + ) + + def test_groupwise_per_channel_quant(self): + m = TestHelperModules.GroupwiseConv2d() + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + example_inputs = m.example_inputs() + m = self._quantize(m, quantizer, example_inputs) + # make sure it runs + m(*example_inputs) + + def test_observer_callback(self): + from torch.library import Library, impl + + test_lib = Library("test_int4", "DEF") # noqa: TOR901 + test_lib.define( + "quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor" + ) + + @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd") + def quantize_per_tensor_int4( + input: torch.Tensor, + scale: float, + zero_point: int, + ) -> torch.Tensor: + inv_scale = 1.0 / scale + return ( + torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15) + .to(torch.uint8) + .view(torch.bits8) + ) + + test_lib.define( + "dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor" + ) + + @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd") + def dequantize_per_tensor_int4( + input: torch.Tensor, + scale: float, + zero_point: int, + ) -> torch.Tensor: + return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale + + from torchao.quantization.pt2e_flow.observer import ObserverBase + + class Int4Observer(ObserverBase): + def __init__(self, *args, **kwargs): + # just faking a dtype here + super().__init__(dtype=torch.int8) + + def forward(self, x): + return x + + def calculate_qparams(self, **kwargs): + pass + + def convert(self, model: torch.fx.GraphModule, observer_node: Node): + with model.graph.inserting_before(observer_node): + q_node = model.graph.call_function( + torch.ops.test_int4.quantize_per_tensor_int4, + (observer_node.args[0], 1.0, 0), + {}, + ) + dq_node = model.graph.call_function( + torch.ops.test_int4.dequantize_per_tensor_int4, + (q_node, 1.0, 0), + {}, + ) + observer_node.replace_all_uses_with(dq_node) + model.graph.erase_node(observer_node) + + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.add.Tensor + ): + input_act0 = node.args[0] + assert isinstance(input_act0, Node) + input_act1 = node.args[1] + assert isinstance(input_act1, Node) + + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=Int4Observer, + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act0: act_qspec, + input_act1: act_qspec, + }, + output_qspec=act_qspec, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def forward(self, x1, x2): + 
return x1 + x2 + + example_inputs = ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + ) + node_occurrence = { + # two for input of the first conv, one for output for the first conv + torch.ops.test_int4.quantize_per_tensor_int4: 3, + torch.ops.test_int4.dequantize_per_tensor_int4: 3, + } + node_list = [ + torch.ops.test_int4.dequantize_per_tensor_int4, + torch.ops.test_int4.dequantize_per_tensor_int4, + torch.ops.aten.add.Tensor, + torch.ops.test_int4.quantize_per_tensor_int4, + ] + self._test_quantizer( + M().eval(), + example_inputs, + BackendAQuantizer(), + node_occurrence, + node_list, + ) + + def test_speed(self): + import time + + def dynamic_quantize_pt2e(model, example_inputs): + torch._dynamo.reset() + model = export_for_training(model, example_inputs).module() + # Per channel quantization for weight + # Dynamic quantization for activation + # Please read a detail: https://fburl.com/code/30zds51q + embedding_quantizer = EmbeddingQuantizer() + dynamic_quantizer = XNNPACKQuantizer() + operator_config_dynamic = get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + dynamic_quantizer.set_global(operator_config_dynamic) + composed_quantizer = ComposableQuantizer( + [embedding_quantizer, dynamic_quantizer] + ) + prev = time.time() + model = prepare_qat_pt2e(model, composed_quantizer) + cur = time.time() + # print("prepare time:", cur - prev) + # Without Calibraiton, scale/zero value will have an initialized value of 1.0 + # Per channel quantization needs a proper scale/zero shape/value to work properly. + # So we need to run calibration before converting to quantized model. + model(*example_inputs) + prev = time.time() + model = convert_pt2e(model) + cur = time.time() + # uncomment to see the time + # print("convert time:", cur - prev) + return model + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + m = M().eval() + example_inputs = (torch.randn(5, 5),) + _ = dynamic_quantize_pt2e(m, example_inputs) + + def test_conv_transpose_bn_relu(self): + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + int8_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_weight_observer, + ) + quantization_config = QuantizationConfig( + input_activation=int8_qspec, + weight=int8_qspec, + bias=None, + output_activation=int8_qspec, + ) + # conv_transpose + bn is fused automatically in PTQ (not configurable) + # so we just need to annotate conv_transpose + relu for conv_transpose + bn + relu + # pattern + OP_TO_ANNOTATOR["conv_transpose_relu"](model, quantization_config) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # two for input of the first conv, one for output for the first conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.relu.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.ConvTWithBNRelu(relu=True, bn=True), + 
example_inputs, + BackendAQuantizer(), + node_occurrence, + node_list, + ) + + def test_multi_users_without_output_observer(self): + """ + Test the case in which a node is used by multiple users, + and had its output observer removed. + """ + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + x = self.conv(x) + return x, x + 1 + + example_inputs = (torch.randn(1, 3, 5, 5),) + m = M() + m = export_for_training(m, example_inputs).module() + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(), + ) + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + + # Remove output observer + observer_to_remove = None + for n in m.graph.nodes: + if n.op == "output": + observer_to_remove = n.args[0][0] + assert observer_to_remove.op == "call_module" + assert observer_to_remove.target.startswith("activation_post_process_") + break + assert observer_to_remove is not None + observer_to_remove.replace_all_uses_with(observer_to_remove.args[0]) + m.graph.erase_node(observer_to_remove) + m.recompile() + + # Convert should succeed + m = convert_pt2e(m) + m(*example_inputs) + + def test_prepare_obs_or_fq_callback(self): + class Model(torch.nn.Module): + def forward(self, x): + x = torch.nn.functional.max_pool2d(x, 2, 2) + x = torch.nn.functional.pixel_shuffle(x, 2) + return x.permute(0, 2, 3, 1) + + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + for node in model.graph.nodes: + if node.op == "call_function" and node.target in ( + torch.ops.aten.max_pool2d.default, + torch.ops.aten.permute.default, + torch.ops.aten.pixel_shuffle.default, + ): + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + node.args[0]: act_qspec, + }, + output_qspec=SharedQuantizationSpec((node.args[0], node)), + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + def prepare_obs_or_fq_callback( + self, + model: torch.fx.GraphModule, + edge_or_node_to_obs_or_fq: dict[EdgeOrNode, ObserverOrFakeQuantize], + ) -> None: + # hard code output quant by updating entire sharing group + output_node = next(n for n in model.graph.nodes if n.op == "output") + output_value = output_node.args[0][0] + old_observer = edge_or_node_to_obs_or_fq[output_value] + sharing_group = [ + k for k, v in edge_or_node_to_obs_or_fq.items() if v is old_observer + ] + new_observer = observer.FixedQParamsObserver( + scale=0.125, + zero_point=42, + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + ) + for x in sharing_group: + edge_or_node_to_obs_or_fq[x] = new_observer + + example_inputs = (torch.rand(1, 32, 16, 16),) + gm = export_for_training(Model().eval(), example_inputs).module() + gm = prepare_pt2e(gm, BackendAQuantizer()) + gm = convert_pt2e(gm) + for n in gm.graph.nodes: + if n.op == "call_function" and n.target in ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ): + # Entire graph share the same qspec which was overriden by FixedQParamsObserver + self.assertEqual(n.args[1], 0.125) + self.assertEqual(n.args[2], 42) + + def test_preserve_nn_module_stack(self): + """Test we can preserve 
nn_module_stack on replaced pattern's nodes""" + m = TestHelperModules.ConvBnReLU2dAndLinearReLU() + example_inputs = (torch.randn(3, 3, 10, 10),) + + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_per_channel=True, is_qat=True) + ) + + def check_nn_module(node): + self.assertTrue("nn_module_stack" in node.meta) + self.assertTrue( + "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1] + ) + + m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module() + for node in m.conv_bn_relu.graph.nodes: + if node.op not in ["placeholder", "output", "get_attr"]: + check_nn_module(node) + m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer) + for node in m.conv_bn_relu.graph.nodes: + if node.name == "mul": + check_nn_module(node) + + +@skipIfNoQNNPACK +class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase): + def test_channel_group_quantization(self): + from torchao.quantization.pt2e_flow.observer import ( + MappingType, + PerGroup, + PerToken, + ) + from torchao.quantization.pt2e_flow.pt2e._affine_quantization import ( + AffineQuantizedMinMaxObserver, + ) + from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 + + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.linear.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=None, + is_dynamic=False, + observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args( + # TODO: maybe align the arg name here + target_dtype=torch.uint8, + mapping_type=MappingType.SYMMETRIC, + granularity=PerToken(), + ), + ) + + weight_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=None, + is_dynamic=False, + observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args( + target_dtype=torch.uint8, + mapping_type=MappingType.SYMMETRIC, + granularity=PerGroup(group_size=128), + ), + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + }, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 20) + + def forward(self, x): + return self.linear(x) + + node_occurrence = { + torch.ops.torchao_quant.quantize_affine: 1 + if TORCH_VERSION_AT_LEAST_2_7 + else 2, + torch.ops.torchao_quant.dequantize_affine: 2, + } + node_list = [ + torch.ops.torchao_quant.quantize_affine, + torch.ops.torchao_quant.dequantize_affine, + ] + example_inputs = (torch.randn(5, 128),) + self._test_quantizer( + M().eval(), + example_inputs, + BackendAQuantizer(), + node_occurrence, + node_list, + is_debug_mode=True, + ) + + +instantiate_parametrized_tests(TestQuantizePT2E) diff --git a/test/quantization/pt2e_flow/test_quantize_pt2e_qat.py b/test/quantization/pt2e_flow/test_quantize_pt2e_qat.py new file mode 100644 index 0000000000..2d575dc140 --- /dev/null +++ b/test/quantization/pt2e_flow/test_quantize_pt2e_qat.py @@ -0,0 +1,1161 @@ +# Owner(s): ["oncall: quantization"] +import copy +import operator +import unittest +from typing import Any, Optional + +import torch +from torch.ao.quantization import QConfigMapping 
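+# Note: the FX graph mode QAT imports (QConfigMapping, prepare_qat_fx) are used
+# only as a numerics reference that the PT2E QAT flow is checked against.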
+from torch.ao.quantization.quantize_fx import prepare_qat_fx +from torch.export import export_for_training +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, +) +from torch.testing._internal.common_quantization import ( + QuantizationTestCase, + skip_if_no_torchvision, + skipIfNoQNNPACK, +) +from torch.testing._internal.common_quantized import override_quantized_engine + +import torchao +from torchao.quantization.pt2e_flow import ( + FusedMovingAvgObsFakeQuantize, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + default_fake_quant, +) +from torchao.quantization.pt2e_flow.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torchao.quantization.pt2e_flow.quantizer import ( + DerivedQuantizationSpec, + QuantizationAnnotation, + QuantizationSpec, + Quantizer, +) +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, +) + + +class PT2EQATTestCase(QuantizationTestCase): + """ + Base QuantizationTestCase for PT2E QAT with some helper methods. + """ + + class _BaseConvBnModel(torch.nn.Module): + def __init__( + self, + conv_class: type[torch.nn.Module], + bn_class: type[torch.nn.Module], + has_conv_bias: bool, + has_bn: bool, + has_relu: bool, + **conv_kwargs, + ): + super().__init__() + conv_kwargs.setdefault("in_channels", 3) + conv_kwargs.setdefault("out_channels", 3) + conv_kwargs.setdefault("kernel_size", 3) + conv_kwargs.setdefault("bias", has_conv_bias) + self.conv = conv_class(**conv_kwargs) + self.bn = bn_class(conv_kwargs["out_channels"]) if has_bn else None + self.relu = torch.nn.ReLU() if has_relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + + def _get_conv_bn_model( + self, + has_conv_bias: bool = True, + has_bn: bool = True, + has_relu: bool = False, + transpose: bool = False, + **conv_kwargs, + ): + """ + Return an instance of a simple test model containing the + conv[-bn][-relu] pattern. By default, this returns a + conv-bn model with conv bias. + """ + return self._BaseConvBnModel( + self.conv_transpose_class if transpose else self.conv_class, + self.bn_class, + has_conv_bias, + has_bn, + has_relu, + **conv_kwargs, + ) + + def _verify_symmetric_xnnpack_qat_numerics( + self, + model: torch.nn.Module, + example_inputs: tuple[Any, ...], + ): + self._verify_symmetric_xnnpack_qat_numerics_helper( + model, + example_inputs, + is_per_channel=True, + ) + self._verify_symmetric_xnnpack_qat_numerics_helper( + model, + example_inputs, + is_per_channel=False, + ) + + def _verify_symmetric_xnnpack_qat_numerics_helper( + self, + model: torch.nn.Module, + example_inputs: tuple[Any, ...], + is_per_channel: bool, + verify_convert: bool = True, + ): + """ + Helper method to verify that the QAT numerics for PT2E quantization match those of + FX graph mode quantization for symmetric qnnpack. 
+ """ + # resetting dynamo cache + torch._dynamo.reset() + MANUAL_SEED = 100 + + # PT2 export + + model_pt2e = copy.deepcopy(model) + quantizer = XNNPACKQuantizer() + quantizer.set_global( + get_symmetric_quantization_config( + is_per_channel=is_per_channel, is_qat=True + ) + ) + model_pt2e = export_for_training( + model_pt2e, + example_inputs, + ).module() + model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer) + torch.manual_seed(MANUAL_SEED) + after_prepare_result_pt2e = model_pt2e(*example_inputs) + + model_fx = copy.deepcopy(model) + # TODO: remove the test if fx quant is removed from pytorch + if is_per_channel: + default_qconfig = ( + torch.ao.quantization.default_per_channel_symmetric_qnnpack_qat_qconfig + ) + else: + default_qconfig = ( + torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig + ) + qconfig_mapping = QConfigMapping().set_global(default_qconfig) + backend_config = ( + torch.ao.quantization.backend_config.get_qnnpack_backend_config() + ) + model_fx = prepare_qat_fx( + model_fx, qconfig_mapping, example_inputs, backend_config=backend_config + ) + torch.manual_seed(MANUAL_SEED) + after_prepare_result_fx = model_fx(*example_inputs) + + # Verify that numerics match + print("model pt2e:", model_pt2e) + print("model fx:", model_fx) + print("diff:", after_prepare_result_pt2e - after_prepare_result_fx) + self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx) + + if verify_convert: + # We don't want to impose any ordering requirements between move_exported_model_to_eval and convert_pt2e + torchao.quantization.pt2e_flow.move_exported_model_to_eval(model_pt2e) + model_pt2e = convert_pt2e(model_pt2e) + quant_result_pt2e = model_pt2e(*example_inputs) + model_fx.eval() + model_fx = ( + torch.ao.quantization.quantize_fx._convert_to_reference_decomposed_fx( + model_fx, + backend_config=backend_config, + ) + ) + quant_result_fx = model_fx(*example_inputs) + self.assertEqual(quant_result_pt2e, quant_result_fx) + + def _verify_symmetric_xnnpack_qat_graph( + self, + m: torch.fx.GraphModule, + example_inputs: tuple[Any, ...], + has_relu: bool, + has_bias: bool = True, + is_cuda: bool = False, + expected_conv_literal_args: Optional[tuple[Any, ...]] = None, + # TODO: set this to true by default + verify_convert: bool = False, + ): + self._verify_symmetric_xnnpack_qat_graph_helper( + m, + example_inputs, + is_per_channel=True, + has_relu=has_relu, + has_bias=has_bias, + is_cuda=is_cuda, + expected_conv_literal_args=expected_conv_literal_args, + verify_convert=verify_convert, + ) + self._verify_symmetric_xnnpack_qat_graph_helper( + m, + example_inputs, + is_per_channel=False, + has_relu=has_relu, + has_bias=has_bias, + is_cuda=is_cuda, + expected_conv_literal_args=expected_conv_literal_args, + verify_convert=verify_convert, + ) + + def _verify_symmetric_xnnpack_qat_graph_helper( + self, + m: torch.fx.GraphModule, + example_inputs: tuple[Any, ...], + is_per_channel: bool, + has_relu: bool, + has_bias: bool = True, + is_cuda: bool = False, + expected_conv_literal_args: Optional[tuple[Any, ...]] = None, + verify_convert: bool = False, + ): + """ + Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern + with fake quantizes inserted into the correct places. + # TODO: also verify that metadata is copied over to the new nodes. 
+ """ + m = copy.deepcopy(m) + quantizer = XNNPACKQuantizer() + quantizer.set_global( + get_symmetric_quantization_config(is_per_channel, is_qat=True) + ) + m = export_for_training( + m, + example_inputs, + ).module() + m = prepare_qat_pt2e(m, quantizer) + m(*example_inputs) + + # Verify: getitem output activation fake quantize + output_node = list(m.graph.nodes)[-1] + output_fq_node = output_node.args[0][0] + self.assertTrue(output_fq_node.target.startswith("activation_post_process_")) + output_fq_mod = getattr(m, output_fq_node.target) + self.assertEqual(type(output_fq_mod), FusedMovingAvgObsFakeQuantize) + self.assertEqual( + type(output_fq_mod.activation_post_process), MovingAverageMinMaxObserver + ) + self.assertEqual(output_fq_mod.dtype, torch.int8) + self.assertEqual(output_fq_mod.quant_min, -128) + self.assertEqual(output_fq_mod.quant_max, 127) + + # Verify: getitem(bn, 0) or relu(getitem(bn, 0)) + if has_relu: + relu_node = output_fq_node.args[0] + bn_node = relu_node.args[0] + self.assertEqual(relu_node.target, torch.ops.aten.relu.default) + else: + relu_node = None + bn_node = output_fq_node.args[0] + + # The relu node takes in the output of bn. + # See NOTE [training ir has no getitem for bn node]. + self.assertEqual(bn_node.target, torch.ops.aten.batch_norm.default) + + # Verify: conv / scale_factor.reshape [+ bias.reshape] + if has_bias: + add_bias_node = bn_node.args[0] + (div_scale_factor_node, bias_reshape_node) = add_bias_node.args + self.assertEqual(add_bias_node.target, torch.ops.aten.add.Tensor) + self.assertEqual(bias_reshape_node.target, torch.ops.aten.reshape.default) + else: + div_scale_factor_node = bn_node.args[0] + (conv_node, scale_factor_reshape_node) = div_scale_factor_node.args + conv_op = conv_node.target + self.assertEqual(div_scale_factor_node.target, torch.ops.aten.div.Tensor) + self.assertTrue(_is_conv_node(conv_node)) + self.assertEqual( + scale_factor_reshape_node.target, torch.ops.aten.reshape.default + ) + + # Verify: conv literal args + if expected_conv_literal_args is not None: + assert ( + len(expected_conv_literal_args) == 6 + ), "wrong num conv args, bad test setup" + for i in range(6): + if i + 3 < len(conv_node.args): + self.assertEqual( + conv_node.args[i + 3], expected_conv_literal_args[i] + ) + + # Verify: conv input activation fake quantize + conv_input_fq_node = conv_node.args[0] + conv_input_node = conv_input_fq_node.args[0] + self.assertTrue( + conv_input_fq_node.target.startswith("activation_post_process_") + ) + conv_input_fq_mod = getattr(m, conv_input_fq_node.target) + self.assertEqual(type(conv_input_fq_mod), FusedMovingAvgObsFakeQuantize) + self.assertEqual( + type(conv_input_fq_mod.activation_post_process), MovingAverageMinMaxObserver + ) + self.assertEqual(conv_input_fq_mod.dtype, torch.int8) + self.assertEqual(conv_input_fq_mod.quant_min, -128) + self.assertEqual(conv_input_fq_mod.quant_max, 127) + self.assertTrue(conv_input_node.op, "placeholder") + + # Verify: conv weight fake quantize + conv_weight_fq_node = conv_node.args[1] + self.assertTrue( + conv_weight_fq_node.target.startswith("activation_post_process_") + ) + conv_weight_fq_mod = getattr(m, conv_weight_fq_node.target) + if is_per_channel: + expected_weight_observer_type = MovingAveragePerChannelMinMaxObserver + else: + expected_weight_observer_type = MovingAverageMinMaxObserver + self.assertEqual(type(conv_weight_fq_mod), FusedMovingAvgObsFakeQuantize) + self.assertEqual( + type(conv_weight_fq_mod.activation_post_process), + expected_weight_observer_type, + ) + 
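+        # The weight fake quant is expected to use the narrow symmetric int8 range
+        # [-127, 127], unlike the activation fake quants above, which use the full
+        # [-128, 127] range.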
self.assertEqual(conv_weight_fq_mod.dtype, torch.int8) + self.assertEqual(conv_weight_fq_mod.quant_min, -127) + self.assertEqual(conv_weight_fq_mod.quant_max, 127) + + # Verify: conv(fq(input), fq(weight * scale_factor.reshape), zero_bias) + zero_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None + mul_weight_scale_factor_node = conv_weight_fq_node.args[0] + ( + conv_weight_fq_node, + scale_factor_reshape_node, + ) = mul_weight_scale_factor_node.args + if has_bias: + self.assertEqual(zero_bias_node.target, torch.ops.aten.zeros_like.default) + else: + self.assertTrue(zero_bias_node is None) + self.assertEqual(mul_weight_scale_factor_node.target, torch.ops.aten.mul.Tensor) + self.assertEqual( + scale_factor_reshape_node.target, torch.ops.aten.reshape.default + ) + + # Verify: scale_factor = bn_weight / sqrt(bn_running_var + eps) + scale_factor_node = scale_factor_reshape_node.args[0] + (bn_weight_node, sqrt_node) = scale_factor_node.args + bn_running_var_add_node = sqrt_node.args[0] + (bn_running_var_node, eps) = bn_running_var_add_node.args + self.assertEqual(scale_factor_node.target, torch.ops.aten.div.Tensor) + self.assertTrue("bn.weight" in bn_weight_node.target) + self.assertTrue("bn.running_var" in bn_running_var_node.target) + self.assertEqual(sqrt_node.target, torch.ops.aten.sqrt.default) + self.assertEqual(bn_running_var_add_node.target, torch.ops.aten.add.Tensor) + self.assertEqual(eps, 1e-5) + + # Optionally check the converted graph + if verify_convert: + m = convert_pt2e(m) + m(*example_inputs) + + if is_per_channel: + conv_weight_dq_op = ( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ) + node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + } + else: + conv_weight_dq_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 3, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(conv_weight_dq_op), + ns.call_function(conv_op), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ] + + self.checkGraphModuleNodes( + m, + expected_node_list=node_list, + expected_node_occurrence=node_occurrence, + ) + + +class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase): + """ + Base TestCase to be used for all conv-bn[-relu] fusion patterns. + """ + + # TODO: how can we avoid adding every new test to dynamo/expected_test_failures? 
+ # Otherwise it fails with the following error: + # torch._dynamo.exc.InternalTorchDynamoError: + # 'QuantizationConfig' object has no attribute '__bool__' + + def setUp(self): + # NB: Skip the test if this is a base class, this is to handle the test + # discovery logic in buck which finds and runs all tests here including + # the base class which we don't want to run + if self.id() and "_Base" in self.id(): + self.skipTest("Skipping test running from base class") + + def test_qat_conv_no_bias(self): + m1 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=True) + m2 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=False) + self._verify_symmetric_xnnpack_qat_numerics(m1, self.example_inputs) + self._verify_symmetric_xnnpack_qat_numerics(m2, self.example_inputs) + + def test_qat_conv_bn_fusion(self): + m = self._get_conv_bn_model() + self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False) + self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_qat_conv_bn_fusion_cuda(self): + m = self._get_conv_bn_model().cuda() + example_inputs = (self.example_inputs[0].cuda(),) + self._verify_symmetric_xnnpack_qat_graph( + m, + example_inputs, + has_relu=False, + is_cuda=True, + ) + self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) + + def test_qat_conv_bn_fusion_literal_args(self): + class M(torch.nn.Module): + def __init__(self, conv_class, bn_class): + super().__init__() + self.conv = conv_class(3, 3, 3, stride=2, padding=4) + self.bn = bn_class(3) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + assert self.dim in [1, 2] + if self.dim == 1: + # stride, padding, dilation, transposed, output_padding, groups + conv_args = ((2,), (4,), (1,), False, (0,), 1) + example_inputs = (torch.randn(1, 3, 5),) + else: + # stride, padding, dilation, transposed, output_padding, groups + conv_args = ((2, 2), (4, 4), (1, 1), False, (0, 0), 1) + example_inputs = (torch.randn(1, 3, 5, 5),) + + m = M(self.conv_class, self.bn_class) + + self._verify_symmetric_xnnpack_qat_graph( + m, + example_inputs, + has_relu=False, + expected_conv_literal_args=conv_args, + ) + self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) + + def test_qat_conv_bn_fusion_no_conv_bias(self): + class M2(torch.nn.Module): + """ + Mixed conv + BN with and without conv bias. 
+ """ + + def __init__(self, conv_class, bn_class): + super().__init__() + self.conv1 = conv_class(3, 3, 3, bias=False) + self.bn1 = bn_class(3) + self.conv2 = conv_class(3, 3, 3, bias=True) + self.bn2 = bn_class(3) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.conv2(x) + x = self.bn2(x) + return x + + m1 = self._get_conv_bn_model(has_conv_bias=False) + m2 = M2(self.conv_class, self.bn_class) + + assert self.dim in [1, 2] + if self.dim == 1: + example_inputs = (torch.randn(3, 3, 5),) + else: + example_inputs = (torch.randn(3, 3, 5, 5),) + + self._verify_symmetric_xnnpack_qat_graph( + m1, + example_inputs, + has_relu=False, + has_bias=False, + ) + self._verify_symmetric_xnnpack_qat_numerics(m1, example_inputs) + self._verify_symmetric_xnnpack_qat_numerics(m2, example_inputs) + + def test_qat_conv_bn_relu_fusion(self): + m = self._get_conv_bn_model(has_relu=True) + self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True) + self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_qat_conv_bn_relu_fusion_cuda(self): + m = self._get_conv_bn_model(has_relu=True).cuda() + example_inputs = (self.example_inputs[0].cuda(),) + self._verify_symmetric_xnnpack_qat_graph( + m, + example_inputs, + has_relu=True, + is_cuda=True, + ) + self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) + + def test_qat_conv_bn_relu_fusion_no_conv_bias(self): + m = self._get_conv_bn_model(has_conv_bias=False, has_relu=True) + self._verify_symmetric_xnnpack_qat_graph( + m, + self.example_inputs, + has_relu=True, + has_bias=False, + ) + self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) + + def test_qat_inplace_add_relu(self): + class M(torch.nn.Module): + def __init__(self, conv_class): + super().__init__() + self.conv = conv_class(1, 1, 1) + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + x0 = x + x = self.conv(x) + x += x0 + x = self.relu(x) + return x + + assert self.dim in [1, 2] + if self.dim == 1: + example_inputs = (torch.randn(1, 1, 3),) + else: + example_inputs = (torch.randn(1, 1, 3, 3),) + + m = M(self.conv_class) + self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) + + def test_qat_update_shared_qspec(self): + """ + Test the case where nodes used in SharedQuantizationSpec were replaced + during QAT subgraph rewriting. + """ + + class M(torch.nn.Module): + def __init__(self, conv_class, bn_class): + super().__init__() + self.conv = conv_class(3, 3, 3) + self.bn = bn_class(3) + self.hardtanh = torch.nn.Hardtanh() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.hardtanh(x) + return x + + m = M(self.conv_class, self.bn_class) + self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) + + def test_qat_preserve_source_fn_stack(self): + """ + Test whether `source_fn_stack` is preserved after QAT fusion. 
+ """ + + class M(torch.nn.Module): + def __init__(self, conv_class, bn_class, backbone): + super().__init__() + self.conv = conv_class(5, 3, 3) + self.bn = bn_class(3) + self.relu = torch.nn.ReLU() + self.backbone = backbone + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + x = self.backbone(x) + return x + + assert self.dim in [1, 2] + if self.dim == 1: + example_inputs = (torch.randn(1, 5, 10),) + else: + example_inputs = (torch.randn(1, 5, 10, 10),) + + # QAT prepare + convert + backbone = self._get_conv_bn_model(has_relu=True) + m = M(self.conv_class, self.bn_class, backbone) + quantizer = XNNPACKQuantizer() + quantizer.set_global(get_symmetric_quantization_config(is_qat=True)) + m = export_for_training(m, example_inputs).module() + m = prepare_qat_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + + # Extract the conv and relu nodes (bn was folded into conv) + first_conv, first_relu, second_conv, second_relu = None, None, None, None + for n in m.graph.nodes: + if n.target == torch.ops.aten.relu.default: + if first_relu is None: + assert first_conv is None, "bad test setup" + first_relu = n + first_conv = n.args[0] + else: + assert second_conv is None, "bad test setup" + second_relu = n + second_conv = n.args[0] + + # Extract the conv weight and bias nodes + def get_conv_weight_and_bias(conv_node: torch.fx.Node): + weight_dq_node = conv_node.args[1] + qweight_node = weight_dq_node.args[0] + bias_node = conv_node.args[2] + assert isinstance(qweight_node, torch.fx.Node) + assert isinstance(bias_node, torch.fx.Node) + return (qweight_node, bias_node) + + _, first_conv_bias = get_conv_weight_and_bias(first_conv) + _, second_conv_bias = get_conv_weight_and_bias(second_conv) + + # Assert that each set of conv, conv weight, and conv bias are in the same partition + def get_source_fn(node: torch.fx.Node): + # E.g. 
[('l__self___backbone1_conv', )] + return node.meta["source_fn_stack"][0][0] + + # we don't preserve this is quantized weight currently since it's folded + # but user can attach "quantization_tag" to the node and it will be preserved + # self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_qweight)) + # self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_qweight)) + + self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_bias)) + self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_bias)) + + # Assert that different sets of convs and relus have different partitions + self.assertNotEqual(get_source_fn(first_conv), get_source_fn(first_relu)) + self.assertNotEqual(get_source_fn(first_conv), get_source_fn(second_conv)) + self.assertNotEqual(get_source_fn(second_conv), get_source_fn(second_relu)) + self.assertNotEqual(get_source_fn(first_relu), get_source_fn(second_relu)) + + # Assert that "backbone" exists only in the second set of conv and relu's partition + self.assertTrue("backbone" not in get_source_fn(first_conv)) + self.assertTrue("backbone" not in get_source_fn(first_relu)) + self.assertTrue("backbone" in get_source_fn(second_conv)) + self.assertTrue("backbone" in get_source_fn(second_relu)) + + def test_qat_conv_bn_bias_derived_qspec(self): + m = self._get_conv_bn_model() + example_inputs = self.example_inputs + m = export_for_training(m, example_inputs).module() + quantizer = ConvBnDerivedBiasQuantizer() + m = prepare_qat_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + m(*example_inputs) + + # Assert that both weight and bias are quantized + (conv_node, _, _) = _get_conv_bn_getitem_nodes(m) + weight_dq = conv_node.args[1] + bias_dq = conv_node.args[2] + self.assertEqual( + weight_dq.target, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ) + self.assertEqual( + bias_dq.target, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ) + weight_getattr = weight_dq.args[0] + bias_getattr = bias_dq.args[0] + self.assertEqual( + weight_getattr.op, + "get_attr", + ) + self.assertEqual( + bias_getattr.op, + "get_attr", + ) + + # Assert that bias scale = weight scale * input scale + input_dq = conv_node.args[0] + input_scale = input_dq.args[1] + bias_scale = bias_dq.args[1] + weight_scale = weight_dq.args[1] + self.assertEqual(bias_scale, input_scale * weight_scale) + + # Assert that args for the bias' quantize and dequantize ops + # are copied correctly after subgraph rewriting + (bias_qmin, bias_qmax, bias_dtype) = bias_dq.args[3:] + self.assertEqual(bias_qmin, -(2**31)) + self.assertEqual(bias_qmax, 2**31 - 1) + self.assertEqual(bias_dtype, torch.int32) + + def test_qat_per_channel_weight_custom_dtype(self): + m = self._get_conv_bn_model() + example_inputs = self.example_inputs + m = export_for_training(m, example_inputs).module() + quantizer = ConvBnInt32WeightQuantizer() + m = prepare_qat_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + m(*example_inputs) + + # Assert that conv weight is quantized per channel + (conv_node, _, _) = _get_conv_bn_getitem_nodes(m) + weight_dq = conv_node.args[1] + self.assertEqual( + weight_dq.target, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ) + weight_getattr = weight_dq.args[0] + self.assertEqual( + weight_getattr.op, + "get_attr", + ) + + # Assert that args for the weight's dequantize ops + # are copied correctly after subgraph rewriting + (dq_axis, dq_qmin, dq_qmax, dq_dtype) = weight_dq.args[3:] + 
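+        # weight_dq is dequantize_per_channel(qweight, scales, zero_points, axis,
+        # quant_min, quant_max, dtype), so args[3:] should reflect the int32
+        # per-channel weight settings configured by ConvBnInt32WeightQuantizer.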
self.assertEqual(dq_axis, 0) + self.assertEqual(dq_qmin, 0) + self.assertEqual(dq_qmax, 2**31 - 1) + self.assertEqual(dq_dtype, torch.int32) + + def _do_test_qat_conv_transpose_bn(self, has_relu: bool): + # Use different in/out channel sizes to test if conv weight is + # properly transposed in QAT pattern + m = self._get_conv_bn_model( + has_relu=has_relu, + transpose=True, + in_channels=3, + out_channels=5, + kernel_size=3, + ) + self._verify_symmetric_xnnpack_qat_graph( + m, + self.example_inputs, + has_relu=has_relu, + verify_convert=True, + ) + + def test_qat_conv_transpose_bn(self): + self._do_test_qat_conv_transpose_bn(has_relu=False) + + def test_qat_conv_transpose_bn_relu(self): + self._do_test_qat_conv_transpose_bn(has_relu=True) + + def test_qat_conv_bn_per_channel_weight_bias(self): + m = self._get_conv_bn_model() + example_inputs = self.example_inputs + m = export_for_training(m, example_inputs).module() + quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True) + m = prepare_qat_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + m(*example_inputs) + + # Expected graph: + # x -> q_tensor -> dq_tensor -> conv -> q_tensor -> dq_tensor -> output + # weight -> q_channel -> dq_channel / + # bias -> q_channel -> dq_channel / + + (conv_node, _, _) = _get_conv_bn_getitem_nodes(m) + conv_op = conv_node.target + conv_weight_dq_op = ( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ) + node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 2, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 2, + } + node_list = [ + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ns.call_function(conv_weight_dq_op), + ns.call_function(conv_weight_dq_op), + ns.call_function(conv_op), + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ), + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ), + ] + self.checkGraphModuleNodes( + m, + expected_node_list=node_list, + expected_node_occurrence=node_occurrence, + ) + + def test_fold_bn_erases_bn_node(self): + """ + Ensure the BN node is erased from the graph after folding + it into conv in `convert_pt2e` even in train mode. 
+ """ + m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False) + m = export_for_training(m, self.example_inputs).module() + quantizer = XNNPACKQuantizer() + quantizer.set_global( + get_symmetric_quantization_config(is_per_channel=False, is_qat=True), + ) + m = prepare_qat_pt2e(m, quantizer) + m = convert_pt2e(m) + (conv_node, bn_node, _) = _get_conv_bn_getitem_nodes(m) + self.assertTrue(conv_node is not None) + self.assertTrue(bn_node is None) + + +@skipIfNoQNNPACK +class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base): + dim = 1 + example_inputs = (torch.randn(1, 3, 5),) + conv_class = torch.nn.Conv1d + conv_transpose_class = torch.nn.ConvTranspose1d + bn_class = torch.nn.BatchNorm1d + + +@skipIfNoQNNPACK +class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base): + dim = 2 + example_inputs = (torch.randn(1, 3, 5, 5),) + conv_class = torch.nn.Conv2d + conv_transpose_class = torch.nn.ConvTranspose2d + bn_class = torch.nn.BatchNorm2d + + +def _is_conv_node(n: torch.fx.Node): + return n.op == "call_function" and n.target in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv_transpose1d, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d, + torch.ops.aten.conv_transpose2d.input, + ] + + +def _get_conv_bn_getitem_nodes(model: torch.fx.GraphModule): + """ + Return a 3-tuple of (conv, bn, getitem) nodes from the graph. + """ + model.graph.eliminate_dead_code() + model.recompile() + conv_node = None + bn_node = None + getitem_node = None + for n in model.graph.nodes: + if _is_conv_node(n): + conv_node = n + if n.target in ( + torch.ops.aten._native_batch_norm_legit.default, + torch.ops.aten.batch_norm.default, + ): + bn_node = n + if n.target == operator.getitem: + getitem_node = n + assert conv_node is not None, "bad test setup" + return (conv_node, bn_node, getitem_node) + + +class ConvBnInt32WeightQuantizer(Quantizer): + """ + Dummy quantizer that annotates conv bn in such a way that the weights + are quantized per channel to int32. + """ + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model) + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=default_fake_quant, + ) + weight_qspec = QuantizationSpec( + dtype=torch.int32, + quant_min=0, + quant_max=2**31 - 1, + qscheme=torch.per_channel_affine, + observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + ), + ) + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + conv_node.args[0]: act_qspec, + conv_node.args[1]: weight_qspec, + }, + _annotated=True, + ) + + # See NOTE [training ir has no getitem for bn node]. + bn_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=act_qspec, + _annotated=True, + ) + return model + + def validate(self, model: torch.fx.GraphModule): + pass + + +class ConvBnDerivedBiasQuantizer(Quantizer): + """ + Dummy quantizer that annotates conv bn in such a way that the bias qparams are + derived from the conv input activation and weight qparams. 
+ """ + + def __init__(self, is_per_channel: bool = False): + super().__init__() + self.is_per_channel = is_per_channel + + def _derive_bias_qparams_from_act_and_weight_qparams(self, obs_or_fqs): + act_scale, _ = obs_or_fqs[0].calculate_qparams() + weight_scale, _ = obs_or_fqs[1].calculate_qparams() + if self.is_per_channel: + bias_scale = act_scale * weight_scale + bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32) + else: + bias_scale = torch.tensor([act_scale * weight_scale], dtype=torch.float32) + bias_zero_point = torch.tensor([0], dtype=torch.int32) + return bias_scale, bias_zero_point + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + if self.is_per_channel: + weight_qscheme = torch.per_channel_symmetric + weight_fq = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + ) + else: + weight_qscheme = torch.per_tensor_affine + weight_fq = default_fake_quant + conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model) + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=default_fake_quant, + ) + weight_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=weight_qscheme, + observer_or_fake_quant_ctr=weight_fq, + ) + bias_qspec = DerivedQuantizationSpec( + derived_from=[ + (conv_node.args[0], conv_node), + (conv_node.args[1], conv_node), + ], + derive_qparams_fn=self._derive_bias_qparams_from_act_and_weight_qparams, + dtype=torch.int32, + quant_min=-(2**31), + quant_max=2**31 - 1, + qscheme=weight_qscheme, + ch_axis=0 if self.is_per_channel else None, + ) + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + conv_node.args[0]: act_qspec, + conv_node.args[1]: weight_qspec, + conv_node.args[2]: bias_qspec, + }, + _annotated=True, + ) + + # NOTE [training ir has no getitem for bn node]. + # getitem is None when we use the training IR. It outputs + # aten.batch_norm.default, which do not need any getitem node. + # In this case, we need to annotate on the batch norm node. + # geteitem node should only be None if we are using training IR. 
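+        # (In the older export IR, batch norm typically appears as
+        # _native_batch_norm_legit followed by an operator.getitem on its tuple
+        # output; _get_conv_bn_getitem_nodes checks for both forms, which is why
+        # getitem_node may be None here.)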
+ + bn_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=act_qspec, + _annotated=True, + ) + return model + + def validate(self, model: torch.fx.GraphModule): + pass + + +@skipIfNoQNNPACK +class TestQuantizePT2EQATModels(PT2EQATTestCase): + @skip_if_no_torchvision + @skipIfNoQNNPACK + def test_qat_resnet18(self): + import torchvision + + with override_quantized_engine("qnnpack"): + example_inputs = (torch.randn(1, 3, 224, 224),) + m = torchvision.models.resnet18() + self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) + + @skip_if_no_torchvision + @skipIfNoQNNPACK + def test_qat_mobilenet_v2(self): + import torchvision + + with override_quantized_engine("qnnpack"): + example_inputs = (torch.randn(1, 3, 224, 224),) + m = torchvision.models.mobilenet_v2() + self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) + + +class TestQuantizeMixQATAndPTQ(QuantizationTestCase): + class TwoLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = torch.nn.Linear(16, 8, bias=False) + self.linear2 = torch.nn.Linear(8, 8) + + def forward(self, x): + return self.linear2(self.linear1(x)) + + class QATPTQTestModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3) + self.linears = TestQuantizeMixQATAndPTQ.TwoLinear() + self.my_linear = torch.nn.Linear(8, 8) + + def forward(self, x): + conv_out = self.conv(x) + permute_out = torch.permute(conv_out, (0, 2, 3, 1)) + linear_out = self.linears(permute_out) + my_linear_out = self.my_linear(linear_out) + # Hardtanh doesnt get quantized via xnnpack quantizer in this test + # because it relies on the propagation rules + # Need to fix this + return torch.nn.functional.hardtanh(my_linear_out) + + def _prepare_qat_linears(self, model): + for name, child in model.named_children(): + if isinstance(child, (torch.nn.Linear, TestQuantizeMixQATAndPTQ.TwoLinear)): + if isinstance(child, torch.nn.Linear): + in_channels = child.weight.size(1) + else: + in_channels = child.linear1.weight.size(1) + + example_input = (torch.rand((1, in_channels)),) + traced_child = export_for_training(child, example_input).module() + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=True, is_qat=True + ) + quantizer.set_global(quantization_config) + traced_child_prepared = prepare_qat_pt2e(traced_child, quantizer) + setattr(model, name, traced_child_prepared) + else: + self._prepare_qat_linears(child) + + def _convert_qat_linears(self, model): + for name, child in model.named_children(): + if isinstance(child, torch.fx.GraphModule): + torchao.quantization.pt2e_flow.move_exported_model_to_eval(child) + converted_child = convert_pt2e(child) + setattr(model, name, converted_child) + else: + self._convert_qat_linears(child) + + def test_mixing_qat_ptq(self): + example_inputs = (torch.randn(2, 3, 4, 4),) + model = TestQuantizeMixQATAndPTQ.QATPTQTestModule() + + self._prepare_qat_linears(model) + + model(*example_inputs) + # must be fixed model.eval() + self._convert_qat_linears(model) + model(*example_inputs) + + model_pt2e = export_for_training( + model, + example_inputs, + ).module() + + quantizer = XNNPACKQuantizer() + quantizer.set_module_type(torch.nn.Linear, None) + quantization_config = get_symmetric_quantization_config() + quantizer.set_global(quantization_config) + model_pt2e = prepare_pt2e(model_pt2e, quantizer) + after_prepare_result_pt2e = model_pt2e(*example_inputs) # noqa: F841 + model_pt2e = 
convert_pt2e(model_pt2e) + quant_result_pt2e = model_pt2e(*example_inputs) # noqa: F841 + + exported_model = torch.export.export(model_pt2e, example_inputs, strict=True) + + node_occurrence = { + # conv2d: 1 for act, 1 for weight, 1 for output + # 3 x linear: 1 for act, 1 for output + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 8, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 9, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 3, + # There needs to be one for hardtanh + } + self.checkGraphModuleNodes( + exported_model.graph_module, expected_node_occurrence=node_occurrence + ) diff --git a/test/quantization/pt2e_flow/test_representation.py b/test/quantization/pt2e_flow/test_representation.py new file mode 100644 index 0000000000..75a8b50906 --- /dev/null +++ b/test/quantization/pt2e_flow/test_representation.py @@ -0,0 +1,314 @@ +# Owner(s): ["oncall: quantization"] +import copy +from typing import Any, Optional + +import torch +from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, +) +from torch.testing._internal.common_quantization import ( + QuantizationTestCase, + TestHelperModules, + skipIfNoQNNPACK, +) + +from torchao.quantization.pt2e_flow.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e_flow.quantizer import Quantizer +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, +) + + +@skipIfNoQNNPACK +class TestPT2ERepresentation(QuantizationTestCase): + def _test_representation( + self, + model: torch.nn.Module, + example_inputs: tuple[Any, ...], + quantizer: Quantizer, + ref_node_occurrence: dict[ns, int], + non_ref_node_occurrence: dict[ns, int], + fixed_output_tol: Optional[float] = None, + output_scale_idx: int = 2, + ) -> torch.nn.Module: + # resetting dynamo cache + torch._dynamo.reset() + model = export_for_training( + model, + example_inputs, + ).module() + model_copy = copy.deepcopy(model) + + model = prepare_pt2e(model, quantizer) + # Calibrate + model(*example_inputs) + model = convert_pt2e(model, use_reference_representation=True) + self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence) + # make sure it runs + pt2e_quant_output = model(*example_inputs) + + # TODO: torchdynamo times out when we do this, we can enable numerical checking + # after that is fixed + model_copy = prepare_pt2e(model_copy, quantizer) + # Calibrate + model_copy(*example_inputs) + model_copy = convert_pt2e(model_copy, use_reference_representation=False) + self.checkGraphModuleNodes( + model_copy, expected_node_occurrence=non_ref_node_occurrence + ) + pt2e_quant_output_copy = model_copy(*example_inputs) + + output_tol = None + if fixed_output_tol is not None: + output_tol = fixed_output_tol + else: + idx = 0 + for n in model_copy.graph.nodes: + if ( + n.target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ): + idx += 1 + if idx == output_scale_idx: + output_tol = n.args[1] + assert output_tol is not None + + # make sure the result is off by one at most in the quantized integer representation + self.assertTrue( + torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) + <= (2 * output_tol + 1e-5) + ) + + def test_static_linear(self): + class M(torch.nn.Module): + def __init__(self) -> None: 
+ super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 5),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_dynamic_linear(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=True + ) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 5),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + fixed_output_tol=1e-4, + ) + + def test_conv2d(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv2d = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + return self.conv2d(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(1, 3, 3, 3),) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_add(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + return x + y + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + M().eval() + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_add_relu(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + out = x + y + out = torch.nn.functional.relu(out) + return out + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + ref_node_occurrence = { + ns.call_function(out_dtype): 2, + } + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence=ref_node_occurrence, + non_ref_node_occurrence={}, + ) + + def test_maxpool2d(self): + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + m_eager = TestHelperModules.ConvMaxPool2d().eval() + + example_inputs = (torch.randn(1, 2, 2, 2),) + + self._test_representation( + m_eager, + example_inputs, + quantizer, + ref_node_occurrence={}, + non_ref_node_occurrence={}, + ) + + def test_qdq_per_channel(self): + """Test representation for quantize_per_channel and dequantize_per_channel op""" + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + # use per channel quantization for weight + operator_config = 
get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + M().eval() + + inputs = [ + (torch.randn(1, 5),), + (torch.randn(1, 3, 5),), + (torch.randn(1, 3, 3, 5),), + (torch.randn(1, 3, 3, 3, 5),), + ] + for example_inputs in inputs: + ref_node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 0, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 0, + } + non_ref_node_occurrence = { + # quantize_per_channel is folded + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 0, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 1, + } + + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence, + non_ref_node_occurrence, + output_scale_idx=2, + ) + + def test_qdq(self): + """Test representation for quantize and dequantize op""" + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + return x + y + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + M().eval() + + example_inputs = ( + torch.randn(1, 3, 3, 3), + torch.randn(1, 3, 3, 3), + ) + ref_node_occurrence = { + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 0, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 0, + } + non_ref_node_occurrence = { + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 3, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 3, + } + self._test_representation( + M().eval(), + example_inputs, + quantizer, + ref_node_occurrence, + non_ref_node_occurrence, + ) diff --git a/test/quantization/pt2e_flow/test_x86inductor_quantizer.py b/test/quantization/pt2e_flow/test_x86inductor_quantizer.py new file mode 100644 index 0000000000..4d8b25ce0b --- /dev/null +++ b/test/quantization/pt2e_flow/test_x86inductor_quantizer.py @@ -0,0 +1,2737 @@ +# Owner(s): ["oncall: quantization"] +import copy +import itertools +from enum import Enum + +import torch +import torch.nn as nn +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, +) +from torch.testing._internal.common_quantization import ( + QuantizationTestCase, + skipIfNoInductorSupport, + skipIfNoX86, +) +from torch.testing._internal.common_quantized import override_quantized_engine +from torch.testing._internal.common_utils import skipIfTorchDynamo + +import torchao.quantization.pt2e_flow.quantizer.x86_inductor_quantizer as xiq +from torchao.quantization.pt2e_flow import ObserverBase +from torchao.quantization.pt2e_flow.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torchao.quantization.pt2e_flow.quantizer.x86_inductor_quantizer import ( + QUANT_ANNOTATION_KEY, + X86InductorQuantizer, +) + + +class NodePosType(Enum): + left = 1 + right = 2 + both = 3 + + +class TestHelperModules: + class SingleConv2dModule(torch.nn.Module): + def __init__(self, with_bn=False) -> None: + super().__init__() + self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1)) + self.bn = torch.nn.BatchNorm2d(6) + self.with_bn = with_bn + + def forward(self, x): + x = self.conv(x) + if self.with_bn: + x = self.bn(x) + return x + + class Conv2dUnaryModule(torch.nn.Module): + def 
__init__(self, post_op, use_bias: bool = False, with_bn=False) -> None: + super().__init__() + self.conv = nn.Conv2d( + 3, 6, (2, 2), stride=(1, 1), padding=(1, 1), bias=use_bias + ) + self.post_op = post_op + self.bn = torch.nn.BatchNorm2d(6) + self.with_bn = with_bn + self.maxpool = torch.nn.MaxPool2d((3, 3)) + + def forward(self, x): + x = self.conv(x) + if self.with_bn: + x = self.bn(x) + x = self.post_op(x) + x = self.maxpool(x) + return x + + class Conv2dAddModule(torch.nn.Module): + def __init__( + self, + inplace_add: bool = False, + conv2d_type: NodePosType = NodePosType.left, + use_bias: bool = False, + with_bn: bool = False, + ) -> None: + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + ) + self.conv2 = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + ) + self.relu = nn.ReLU() + self.inplace_add = inplace_add + self.conv2d_type = conv2d_type + self.bn = torch.nn.BatchNorm2d(3) + self.with_bn = with_bn + + def forward(self, x): + if self.conv2d_type == NodePosType.left: + if self.inplace_add: + tmp = self.conv(x) + if self.with_bn: + tmp = self.bn(tmp) + tmp += self.relu(x) + return tmp + else: + tmp = self.conv(x) + if self.with_bn: + tmp = self.bn(tmp) + return tmp + self.relu(x) + elif self.conv2d_type == NodePosType.right: + if self.inplace_add: + tmp = self.relu(x) + tmp += self.conv(x) + return tmp + else: + return self.relu(x) + self.conv(x) + elif self.conv2d_type == NodePosType.both: + if self.inplace_add: + tmp = self.conv(x) + tmp += self.conv2(x) + return tmp + else: + return self.conv(x) + self.conv2(x) + + class Conv2dAddReLUModule(torch.nn.Module): + def __init__( + self, + inplace_add: bool = False, + conv2d_type: NodePosType = NodePosType.left, + inplace_relu: bool = False, + use_bias: bool = False, + with_bn: bool = False, + ) -> None: + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + ) + self.conv2 = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + ) + self.relu = nn.ReLU() + self.inplace_add = inplace_add + self.conv2d_type = conv2d_type + self.relu2 = nn.ReLU(inplace=inplace_relu) + self.bn = torch.nn.BatchNorm2d(3) + self.with_bn = with_bn + + def forward(self, x): + if self.conv2d_type == NodePosType.left: + if self.inplace_add: + tmp = self.conv(x) + if self.with_bn: + tmp = self.bn(tmp) + tmp += self.relu(x) + return self.relu2(tmp) + else: + tmp = self.conv(x) + if self.with_bn: + tmp = self.bn(tmp) + return self.relu2(tmp + self.relu(x)) + elif self.conv2d_type == NodePosType.right: + if self.inplace_add: + tmp = self.relu(x) + tmp += self.conv(x) + return self.relu2(tmp) + else: + return self.relu2(self.relu(x) + self.conv(x)) + elif self.conv2d_type == NodePosType.both: + if self.inplace_add: + tmp = self.conv(x) + tmp += self.conv2(x) + return self.relu2(tmp) + else: + return self.relu2(self.conv(x) + self.conv2(x)) + + class Conv2dSingleOpPowModule(nn.Module): + def __init__(self, single_op): + super().__init__() + self.conv = nn.Conv2d(2, 2, 1) + self.single_op = single_op + + def forward(self, x): + x = self.conv(x) + x = self.single_op(x) + return torch.pow(x, 2) + + class SerialsConv2dAddReLUModule(torch.nn.Module): + """Serials of 2 Conv2d -> Add -> ReLU Pattern.""" + + def __init__( + self, + ) -> None: + super().__init__() + 
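+            # Two chained conv -> add -> relu blocks: conv2/conv3 feed the first
+            # add + relu, then conv4 plus a residual of that result feeds the second
+            # (see forward below).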
self.conv = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.conv2 = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.conv3 = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.conv4 = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.relu = nn.ReLU() + self.relu2 = nn.ReLU() + + def forward(self, x): + x1 = self.conv(x) + res1 = self.relu(self.conv2(x1) + self.conv3(x1)) + res2 = self.relu2(self.conv4(res1) + res1) + return res2 + + class Conv2dCatMaxpool2d(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 + ) + self.conv2 = torch.nn.Conv2d( + 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 + ) + self.relu = torch.nn.ReLU() + self.maxpool = torch.nn.MaxPool2d(3, stride=2, padding=1) + self.conv3 = torch.nn.Conv2d( + 32, 32, 7, bias=True, stride=2, padding=3, dilation=1 + ) + + def forward(self, x): + temp1 = self.relu(self.conv(x)) + temp2 = self.conv2(x + 1) + temp3 = torch.cat((temp1, temp2), 1) + temp4 = self.maxpool(temp3) + temp5 = self.conv3(temp4) + return temp5 + + class Conv2dAvgPool2d(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 + ) + self.avgpool = torch.nn.AvgPool2d(3, stride=2, padding=1) + + def forward(self, x): + temp1 = self.avgpool(self.conv(x)) + return temp1 + + class Conv2dCatSameInputs(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 + ) + self.relu = torch.nn.ReLU() + + def forward(self, x): + temp1 = self.relu(self.conv(x)) + temp3 = torch.cat((temp1, temp1), 1) + return temp3 + + class Conv2dCatSingleInput(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 + ) + self.relu = torch.nn.ReLU() + + def forward(self, x): + temp1 = self.relu(self.conv(x)) + temp3 = torch.cat((temp1,), 1) + return temp3 + + class SingleLinearModule(torch.nn.Module): + def __init__(self, use_bias) -> None: + super().__init__() + self.linear = nn.Linear(4, 4, bias=use_bias) + + def forward(self, x): + return self.linear(x) + + class LinearUnaryModule(torch.nn.Module): + def __init__( + self, use_bias, postop, inplace_postop=False, post_op_algo="none" + ) -> None: + super().__init__() + self.linear = nn.Linear(4, 4, bias=use_bias) + if postop == nn.GELU: + self.postop = postop(approximate=post_op_algo) + else: + self.postop = postop(inplace=inplace_postop) + + def forward(self, x): + return self.postop(self.linear(x)) + + class LinearAddModule(torch.nn.Module): + def __init__( + self, + inplace_add: bool = False, + linear_pos: NodePosType = NodePosType.left, + use_bias: bool = False, + ) -> None: + super().__init__() + self.linear = torch.nn.Linear( + in_features=16, out_features=16, bias=use_bias + ) + self.linear2 = torch.nn.Linear( + in_features=16, out_features=16, bias=use_bias + ) + self.relu = nn.ReLU() + self.inplace_add = inplace_add + self.linear_pos = linear_pos + + def forward(self, x): + if self.linear_pos == NodePosType.left: + if self.inplace_add: + tmp = self.linear(x) + tmp += self.relu(x) 
+ return tmp + else: + tmp = self.linear(x) + return tmp + self.relu(x) + elif self.linear_pos == NodePosType.right: + if self.inplace_add: + tmp = self.relu(x) + tmp += self.linear(x) + return tmp + else: + return self.relu(x) + self.linear(x) + elif self.linear_pos == NodePosType.both: + if self.inplace_add: + tmp = self.linear(x) + tmp += self.linear2(x) + return tmp + else: + return self.linear(x) + self.linear2(x) + + class LinearAddReLUModule(torch.nn.Module): + def __init__( + self, + inplace_add: bool = False, + linear_pos: NodePosType = NodePosType.left, + inplace_relu: bool = False, + use_bias: bool = False, + ) -> None: + super().__init__() + self.linear = torch.nn.Linear( + in_features=16, out_features=16, bias=use_bias + ) + self.linear2 = torch.nn.Linear( + in_features=16, out_features=16, bias=use_bias + ) + self.relu = nn.ReLU() + self.inplace_add = inplace_add + self.linear_pos = linear_pos + self.relu2 = nn.ReLU(inplace=inplace_relu) + + def forward(self, x): + if self.linear_pos == NodePosType.left: + if self.inplace_add: + tmp = self.linear(x) + tmp += self.relu(x) + return self.relu2(tmp) + else: + tmp = self.linear(x) + return self.relu2(tmp + self.relu(x)) + elif self.linear_pos == NodePosType.right: + if self.inplace_add: + tmp = self.relu(x) + tmp += self.linear(x) + return self.relu2(tmp) + else: + return self.relu2(self.relu(x) + self.linear(x)) + elif self.linear_pos == NodePosType.both: + if self.inplace_add: + tmp = self.linear(x) + tmp += self.linear2(x) + return self.relu2(tmp) + else: + return self.relu2(self.linear(x) + self.linear2(x)) + + class SerialsLinearAddReLUModule(torch.nn.Module): + """Serials of 2 Linear -> Add -> ReLU Pattern.""" + + def __init__( + self, + ) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True) + self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True) + self.linear3 = torch.nn.Linear(in_features=16, out_features=16, bias=True) + self.linear4 = torch.nn.Linear(in_features=16, out_features=16, bias=True) + self.relu = nn.ReLU() + self.relu2 = nn.ReLU() + + def forward(self, x): + x1 = self.linear(x) + res1 = self.relu(self.linear2(x1) + self.linear3(x1)) + res2 = self.relu2(self.linear4(res1) + res1) + return res2 + + class LinearAddModule2(torch.nn.Module): + def __init__( + self, + inplace_add: bool = False, + ) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True) + self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True) + self.inplace_add = inplace_add + + def forward(self, x): + if self.inplace_add: + tmp = self.linear(x) + tmp += self.linear2(tmp) + return tmp + else: + tmp = self.linear(x) + return tmp + self.linear2(tmp) + + class Conv2dAddModule2(torch.nn.Module): + def __init__( + self, + inplace_add: bool = False, + ) -> None: + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ) + self.conv2 = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ) + self.inplace_add = inplace_add + self.bn = torch.nn.BatchNorm2d(3) + self.bn2 = torch.nn.BatchNorm2d(3) + + def forward(self, x): + if self.inplace_add: + tmp = self.bn(self.conv(x)) + tmp += self.bn2(self.conv2(tmp)) + return tmp + else: + tmp = self.bn(self.conv(x)) + return tmp + self.bn2(self.conv2(tmp)) + + class SelfAttnLikeModule(torch.nn.Module): + def __init__( + self, + input_dim, + transpose_for_score=False, + 
num_attention_heads=None, + attention_head_size=None, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.q_proj = nn.Linear(input_dim, input_dim, bias=False) + self.k_proj = nn.Linear(input_dim, input_dim, bias=False) + self.v_proj = nn.Linear(input_dim, input_dim, bias=False) + self.softmax = nn.Softmax(dim=-1) + self.transpose_for_score = transpose_for_score + if self.transpose_for_score: + assert num_attention_heads is not None + assert attention_head_size is not None + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, x): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + if self.transpose_for_score: + q = self.transpose_for_scores(q) + k = self.transpose_for_scores(k) + v = self.transpose_for_scores(v) + scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) + attention = self.softmax(scores) + weighted = torch.matmul(attention, v) + return weighted + + class Conv2dFlattenTranspose(nn.Module): + def __init__(self): + super().__init__() + self.projection = torch.nn.Conv2d( + 3, 768, kernel_size=(16, 16), stride=(16, 16) + ) + self.cls_token = torch.rand(1, 1, 768) + + def forward(self, pixel_values): + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + embeddings = torch.cat((self.cls_token, embeddings), dim=1) + return embeddings + + class Conv2dFlattenCatTranspose(nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16)) + + def forward(self, x): + y = self.conv(x).flatten(2) + y = torch.cat([y, y], dim=-1) + return y.transpose(1, 2) + + +class X86InductorQuantTestCase(QuantizationTestCase): + def _test_quantizer( + self, + model, + example_inputs, + quantizer, + expected_node_occurrence, + expected_node_list=None, + is_qat=False, + debug=False, + ): + m_eager = model.train() if is_qat else model.eval() + + # program capture + m = copy.deepcopy(m_eager) + m = export_for_training( + m, + example_inputs, + ).module() + + # QAT Model failed to deepcopy + export_model = m if is_qat else copy.deepcopy(m) + m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer) + # Calibrate + m(*example_inputs) + prepare_model = copy.deepcopy(m) + m = convert_pt2e(m) + convert_model = copy.deepcopy(m) + if debug: + convert_model.print_readable(True) + m(*example_inputs) + node_occurrence = { + ns.call_function(k): v for k, v in expected_node_occurrence.items() + } + if expected_node_list is None: + expected_node_list = [] + node_list = [ns.call_function(n) for n in expected_node_list] + self.checkGraphModuleNodes( + m, expected_node_occurrence=node_occurrence, expected_node_list=node_list + ) + return export_model, prepare_model, convert_model + + +@skipIfNoInductorSupport +class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): + @skipIfNoX86 + def test_conv2d(self): + """ + Test pattern of single conv2d with X86InductorQuantizer. 
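+        Expect a single quantize/dequantize_per_tensor pair on the conv activation;
+        the weight's quantize_per_channel is const-propagated, so only its
+        dequantize_per_channel remains in the converted graph.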
+ """ + with override_quantized_engine("x86"), torch.no_grad(): + m = TestHelperModules.SingleConv2dModule().eval() + example_inputs = (torch.randn(2, 3, 16, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + node_occurrence = { + # one for input and weight of the conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_conv2d_unary(self): + """ + Test pattern of conv2d with unary post ops (such as relu, hardtanh, hardswish, relu6) with X86InductorQuantizer. + """ + unary_map = { + "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default], + "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default], + "hardtanh": [ + torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), + torch.ops.aten.hardtanh.default, + ], + "hardtanh_inplace": [ + torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), + torch.ops.aten.hardtanh_.default, + ], + "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default], + "relu6_inplace": [ + torch.nn.ReLU6(inplace=True), + torch.ops.aten.hardtanh_.default, + ], + "hardswish": [ + torch.nn.Hardswish(inplace=False), + torch.ops.aten.hardswish.default, + ], + "hardswish_inplace": [ + torch.nn.Hardswish(inplace=True), + torch.ops.aten.hardswish_.default, + ], + "swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default], + "swish_inplace": [ + torch.nn.SiLU(inplace=True), + torch.ops.aten.silu_.default, + ], + } + use_bias_list = [True, False] + with override_quantized_engine("x86"), torch.no_grad(): + for unary_op, use_bias in itertools.product( + unary_map.keys(), use_bias_list + ): + m = TestHelperModules.Conv2dUnaryModule( + unary_map[unary_op][0], use_bias=use_bias + ).eval() + example_inputs = (torch.randn(2, 3, 16, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + node_occurrence = { + # one for input and weight of the conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + unary_map[unary_op][1], + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_conv2d_binary(self): + """ + Test pattern of conv2d with binary post ops (such as add) with X86InductorQuantizer. + Currently, only add as binary post op is supported. 
+ """ + conv2d_type_list = [NodePosType.left, NodePosType.both] + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + with override_quantized_engine("x86"), torch.no_grad(): + for conv2d_type in conv2d_type_list: + m = TestHelperModules.Conv2dAddModule(conv2d_type=conv2d_type).eval() + if conv2d_type != NodePosType.both: + node_occurrence = { + # one for input and weight of the conv + # one for extra input node of add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + else: + node_occurrence = { + # one for input of the conv + # one for input of another conv + # 2 conv will share same input quant/dequant + # one for extra input node of add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.add.Tensor, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_conv2d_binary2(self): + """ + Test Pattern: + tmp = conv2d_1(x) + tmp2 = conv2d_2(tmp) + return tmp + tmp2 + Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1 + """ + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + inplace_add_list = [True, False] + with override_quantized_engine("x86"), torch.no_grad(): + for inplace_add in inplace_add_list: + m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add).eval() + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_conv2d_binary_unary(self): + """ + Test pattern of conv2d with binary + unary post ops (such as add + relu) with X86InductorQuantizer. + Currently, only add as binary post op and relu as unary post op are supported. 
+ """ + conv2d_type_list = [NodePosType.left, NodePosType.both] + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + with override_quantized_engine("x86"), torch.no_grad(): + for conv2d_type in conv2d_type_list: + m = TestHelperModules.Conv2dAddReLUModule( + conv2d_type=conv2d_type, + ).eval() + if conv2d_type != NodePosType.both: + node_occurrence = { + # one for input for conv + # one for extra input node of add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + else: + node_occurrence = { + # one for input of the conv + # one for input of another conv + # 2 conv will share same input quant/dequant + # one for extra input node of add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.add.Tensor, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_conv2d_serials_binary_unary(self): + """ + Test pattern of 2 following up conv2d add relu with X86InductorQuantizer. 
+ """ + with override_quantized_engine("x86"), torch.no_grad(): + m = TestHelperModules.SerialsConv2dAddReLUModule().eval() + example_inputs = (torch.randn(2, 3, 16, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 6, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 4, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.add.Tensor, + torch.ops.aten.relu.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def _single_op_share_observer_recipe_test_helper(self, m, x, single_op): + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (x,) + node_occurrence = { + # one for input and weight of the conv, two for input/output for the maxpool2d + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + single_op, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + _, prepare_model, _ = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + # Check Maxpool2d has share observer at input and output + for node in prepare_model.graph.nodes: + if node.op == "call_function" and node.target is single_op: + single_op_node = node + input_obs_of_single_op = getattr( + prepare_model, single_op_node.args[0].target + ) + output_obs_of_single_op = getattr( + prepare_model, next(iter(single_op_node.users)).target + ) + elif ( + node.op == "call_function" + and node.target is torch.ops.aten.conv2d.default + ): + conv_node = node + input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target) + self.assertTrue(isinstance(input_obs_of_single_op, ObserverBase)) + self.assertTrue(isinstance(output_obs_of_single_op, ObserverBase)) + self.assertTrue(isinstance(input_obs_of_conv, ObserverBase)) + self.assertTrue(input_obs_of_single_op is output_obs_of_single_op) + self.assertTrue(input_obs_of_single_op is not input_obs_of_conv) + + @skipIfNoX86 + def test_maxpool2d_recipe(self): + r""" + Test pattern: int8_in_int8_out_ops(maxpool) - non_quantizable op(pow) + Since maxpool is a int8_in_int8_out_op, there is obs between maxpool and pow. 
+ """ + self._single_op_share_observer_recipe_test_helper( + TestHelperModules.Conv2dSingleOpPowModule(nn.MaxPool2d(1, 1)).eval(), + torch.rand(1, 2, 14, 14), + torch.ops.aten.max_pool2d.default, + ) + + @skipIfNoX86 + def test_adaptive_avg_pool2d_recipe(self): + r""" + Test pattern: int8_in_int8_out_ops(adaptive_avg_pool2d) - non_quantizable op(pow) + Since adaptive_avg_pool2d is a int8_in_int8_out_op, there is obs between adaptive_avg_pool2d and pow. + """ + self._single_op_share_observer_recipe_test_helper( + TestHelperModules.Conv2dSingleOpPowModule( + nn.AdaptiveAvgPool2d((1, 1)) + ).eval(), + torch.rand(1, 2, 14, 14), + torch.ops.aten.adaptive_avg_pool2d.default, + ) + + @skipIfNoX86 + def test_flatten_recipe(self): + r""" + Test pattern: conv -> flatten -> cat -> transpose + """ + m = TestHelperModules.Conv2dFlattenCatTranspose().eval() + x = torch.randn(1, 3, 224, 224) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (x,) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.flatten.using_ints, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.cat.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + _, prepare_model, _ = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + # Check Flatten has share observer at input and output + for node in prepare_model.graph.nodes: + if ( + node.op == "call_function" + and node.target is torch.ops.aten.flatten.using_ints + ): + single_op_node = node + input_obs_of_single_op = getattr( + prepare_model, single_op_node.args[0].target + ) + output_obs_of_single_op = getattr( + prepare_model, next(iter(single_op_node.users)).target + ) + elif ( + node.op == "call_function" + and node.target is torch.ops.aten.conv2d.default + ): + conv_node = node + input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target) + self.assertTrue(isinstance(input_obs_of_single_op, ObserverBase)) + self.assertTrue(isinstance(output_obs_of_single_op, ObserverBase)) + self.assertTrue(isinstance(input_obs_of_conv, ObserverBase)) + self.assertTrue(input_obs_of_single_op is output_obs_of_single_op) + self.assertTrue(input_obs_of_single_op is not input_obs_of_conv) + + @skipIfNoX86 + def test_flatten_recipe2(self): + r""" + Test pattern: conv -> flatten -> transpose + """ + m = TestHelperModules.Conv2dFlattenTranspose().eval() + x = torch.randn(1, 3, 224, 224) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (x,) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # quantize_per_channel for weights 
are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.flatten.using_ints, + torch.ops.aten.transpose.int, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_cat_recipe(self): + r""" + Test pattern: conv -> cat -> maxpool2d + Since cat, maxpool is a int8_in_int8_out_op, the inputs and outputs should with same observer. + """ + m = TestHelperModules.Conv2dCatMaxpool2d().eval() + x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (x,) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 6, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 6, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.cat.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.max_pool2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + _, prepare_model, _ = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + # Check Cat/Maxpool2d has share observer at input and output + for node in prepare_model.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.cat.default: + cat_act_obs0 = getattr(prepare_model, node.all_input_nodes[0].target) + cat_act_obs1 = getattr(prepare_model, node.all_input_nodes[1].target) + cat_out_obs = getattr(prepare_model, next(iter(node.users)).target) + elif ( + node.op == "call_function" + and node.target is torch.ops.aten.max_pool2d.default + ): + maxpool_node = node + input_obs_of_maxpool = getattr( + prepare_model, maxpool_node.args[0].target + ) + output_obs_of_maxpool = getattr( + prepare_model, next(iter(maxpool_node.users)).target + ) + self.assertTrue(isinstance(cat_act_obs0, ObserverBase)) + self.assertTrue(isinstance(cat_act_obs1, ObserverBase)) + self.assertTrue(isinstance(cat_out_obs, ObserverBase)) + self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase)) + self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase)) + self.assertTrue(cat_act_obs0 is cat_act_obs1) + self.assertTrue(cat_act_obs0 is cat_out_obs) + self.assertTrue(cat_out_obs is input_obs_of_maxpool) + self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool) + + @skipIfNoX86 + def test_cat_recipe_same_inputs(self): + r""" + Test pattern: conv -> cat([input0, input0]) + Since cat has 2 input node of same tensor, they should also be with same observer. 
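+        Both cat inputs and the cat output are expected to map to one shared
+        observer instance.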
+ """ + m = TestHelperModules.Conv2dCatSameInputs().eval() + x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (x,) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.cat.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + _, prepare_model, _ = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + # Check Cat has share observer at input and output + for node in prepare_model.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.cat.default: + cat_act_obs0 = getattr(prepare_model, node.args[0][0].target) + cat_act_obs1 = getattr(prepare_model, node.args[0][1].target) + cat_out_obs = getattr(prepare_model, next(iter(node.users)).target) + self.assertTrue(isinstance(cat_act_obs0, ObserverBase)) + self.assertTrue(isinstance(cat_act_obs1, ObserverBase)) + self.assertTrue(isinstance(cat_out_obs, ObserverBase)) + self.assertTrue(cat_act_obs0 is cat_act_obs1) + self.assertTrue(cat_act_obs0 is cat_out_obs) + + @skipIfNoX86 + def test_cat_recipe_single_input(self): + r""" + Test pattern: conv -> cat([input0,]) + Since cat has 1 input node, they should also be with same observer. 
+ """ + m = TestHelperModules.Conv2dCatSingleInput().eval() + x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (x,) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.cat.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + _, prepare_model, _ = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + # Check Cat has share observer at input and output + for node in prepare_model.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.cat.default: + cat_act_obs0 = getattr(prepare_model, node.args[0][0].target) + cat_out_obs = getattr(prepare_model, next(iter(node.users)).target) + self.assertTrue(isinstance(cat_act_obs0, ObserverBase)) + self.assertTrue(isinstance(cat_out_obs, ObserverBase)) + self.assertTrue(cat_act_obs0 is cat_out_obs) + + @skipIfNoX86 + def test_avg_pool2d_recipe(self): + r""" + Test pattern: conv -> AvgPool2d + Since AvgPool2d is a int8_in_int8_out_op, the inputs and outputs should with same observer. 
+ """ + m = TestHelperModules.Conv2dAvgPool2d().eval() + x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + example_inputs = (x,) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.avg_pool2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + _, prepare_model, _ = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + for node in prepare_model.graph.nodes: + if ( + node.op == "call_function" + and node.target is torch.ops.aten.avg_pool2d.default + ): + avgpool_node = node + input_obs_of_avgpool = getattr( + prepare_model, avgpool_node.args[0].target + ) + output_obs_of_avgpool = getattr( + prepare_model, next(iter(avgpool_node.users)).target + ) + elif ( + node.op == "call_function" + and node.target is torch.ops.aten.conv2d.default + ): + conv_node = node + output_obs_of_conv = getattr( + prepare_model, next(iter(conv_node.users)).target + ) + self.assertTrue(isinstance(input_obs_of_avgpool, ObserverBase)) + self.assertTrue(isinstance(output_obs_of_avgpool, ObserverBase)) + self.assertTrue(isinstance(output_obs_of_conv, ObserverBase)) + self.assertTrue(input_obs_of_avgpool is output_obs_of_avgpool) + self.assertTrue(input_obs_of_avgpool is output_obs_of_conv) + + @skipIfNoX86 + def test_linear(self): + """ + Test pattern of single linear with X86InductorQuantizer. + """ + with override_quantized_engine("x86"), torch.no_grad(): + for use_bias in [True, False]: + m = TestHelperModules.SingleLinearModule(use_bias).eval() + example_inputs = (torch.randn(2, 4),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + node_occurrence = { + # one for input and weight, one for output + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def _test_linear_unary_helper( + self, + post_op_module, + post_op_aten, + post_op_aten_inplace, + post_op_algo_list=None, + is_qat=False, + is_dynamic=False, + ): + """ + Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer. 
+ """ + use_bias_list = [True, False] + # TODO test for inplace add after refactoring of export_for_training + inplace_list = [False] + if post_op_algo_list is None: + post_op_algo_list = [None] + cases = itertools.product(use_bias_list, inplace_list, post_op_algo_list) + with override_quantized_engine("x86"), torch.no_grad(): + for use_bias, inplace, post_op_algo in cases: + if inplace and post_op_aten_inplace is None: + continue + m = TestHelperModules.LinearUnaryModule( + use_bias=use_bias, + postop=post_op_module, + inplace_postop=inplace, + post_op_algo=post_op_algo, + ).eval() + example_inputs = (torch.randn(2, 4),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + node_occurrence = { + # one for input of the linear + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + quantize_per_tensor_op, + dequantize_per_tensor_op, + torch.ops.aten.linear.default, + post_op_aten_inplace if inplace else post_op_aten, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=is_qat, + ) + + @skipIfNoX86 + def test_linear_unary(self): + aten = torch.ops.aten + self._test_linear_unary_helper(nn.ReLU, aten.relu.default, aten.relu_.default) + self._test_linear_unary_helper( + nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default + ) + self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"] + ) + + @skipIfNoX86 + def test_linear_unary_qat(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_qat=True + ) + self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_qat=True + ) + + @skipIfNoX86 + def test_linear_unary_dynamic(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_dynamic=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, + aten.leaky_relu.default, + aten.leaky_relu_.default, + is_dynamic=True, + ) + self._test_linear_unary_helper( + nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_dynamic=True + ) + + @skipIfNoX86 + def test_linear_unary_dynamic_qat(self): + aten = torch.ops.aten + self._test_linear_unary_helper( + nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True, is_dynamic=True + ) + self._test_linear_unary_helper( + nn.LeakyReLU, + aten.leaky_relu.default, + aten.leaky_relu_.default, + is_qat=True, + is_dynamic=True, + ) + self._test_linear_unary_helper( + nn.GELU, + aten.gelu.default, + None, + ["none", "tanh"], + is_qat=True, + is_dynamic=True, + ) + + def _check_annotation_stat(self, gm, expected_stat_dict): + # Check expected annotation statistics to ensure the annotation is correct + + def _check_annotation(node): + annot = 
node.meta.get(QUANT_ANNOTATION_KEY, None) + if annot is None: + return False, False + return annot._annotated, annot._is_output_of_quantized_pattern + + for node in gm.graph.nodes: + if node.target in expected_stat_dict.keys(): + annotated, is_quant_out = _check_annotation(node) + expected_stat_dict[node.target]["annotated"] -= annotated + expected_stat_dict[node.target]["is_quant_out"] -= is_quant_out + for op_stat in expected_stat_dict.values(): + assert all(v == 0 for v in op_stat.values()) + + def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): + """ + Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer. + Currently, only add as binary post op is supported. + """ + linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] + # TODO test for inplace add after refactoring of export_for_training + inplace_add_list = [False] + example_inputs = (torch.randn(2, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + cases = itertools.product(linear_pos_list, inplace_add_list) + with override_quantized_engine("x86"), torch.no_grad(): + for linear_pos, inplace_add in cases: + m = TestHelperModules.LinearAddModule( + inplace_add=inplace_add, linear_pos=linear_pos + ).eval() + if linear_pos != NodePosType.both: + node_occurrence = { + # Only one 1 q-dq for input of the linear + # No q-dq for extra input node of add + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + else: + # convert_pt2e disables duplicate dequant for dynamic quant + num_dequant = 1 if is_dynamic else 2 + node_occurrence = { + # One quantize_per_tensor for both linear nodes (shared) + # Two dequantize_per_tensor for two linear nodes + # No q-dq for extra input node of add + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: num_dequant, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + quantize_per_tensor_op, + dequantize_per_tensor_op, + torch.ops.aten.linear.default, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), + ] + fq_m = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=is_qat, + )[-1] + # One linear and add are fused. 
The other linear is quantized alone if present + aten = torch.ops.aten + add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor + expected_annotation_stat = { + aten.linear.default: { + "annotated": 2 if linear_pos == NodePosType.both else 1, + "is_quant_out": 1 if linear_pos == NodePosType.both else 0, + }, + add_op: {"annotated": 1, "is_quant_out": 1}, + } + self._check_annotation_stat(fq_m, expected_annotation_stat) + + @skipIfNoX86 + def test_linear_binary(self): + self._test_linear_binary_helper() + + @skipIfNoX86 + def test_linear_binary_qat(self): + self._test_linear_binary_helper(is_qat=True) + + @skipIfNoX86 + def test_linear_binary_dynamic(self): + self._test_linear_binary_helper(is_dynamic=True) + + @skipIfNoX86 + def test_linear_binary_dynamic_qat(self): + self._test_linear_binary_helper(is_qat=True, is_dynamic=True) + + @skipIfNoX86 + def test_linear_binary2(self): + """ + Test Pattern: + tmp = linear_1(x) + tmp2 = linear_2(tmp) + return tmp + tmp2 + Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1 + """ + example_inputs = (torch.randn(2, 16),) + # TODO test for inplace add after refactoring of export_for_training + inplace_add_list = [False] + is_qat_list = [False, True] + is_dynamic_list = [False, True] + cases = itertools.product(inplace_add_list, is_qat_list, is_dynamic_list) + with override_quantized_engine("x86"), torch.no_grad(): + for inplace_add, is_qat, is_dynamic in cases: + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, is_dynamic=is_dynamic + ) + ) + m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval() + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + # Two q-dq nodes for inputs of linear nodes + # No q-dq for extra input node of add + node_occurrence = { + quantize_per_tensor_op: 2, + dequantize_per_tensor_op: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + quantize_per_tensor_op, + dequantize_per_tensor_op, + torch.ops.aten.linear.default, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), + ] + fq_m = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + )[-1] + # One linear and add are fused. The other linear is quantized alone if present + aten = torch.ops.aten + add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor + expected_annotation_stat = { + aten.linear.default: { + "annotated": 2, + "is_quant_out": 1, + }, + add_op: {"annotated": 1, "is_quant_out": 1}, + } + self._check_annotation_stat(fq_m, expected_annotation_stat) + + @skipIfNoX86 + def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False): + """ + Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer. + Currently, only add as binary post op and relu as unary post op are supported. 
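+        When both add inputs come from linears, the two linears share one input
+        quantize; for dynamic quantization convert_pt2e keeps a single dequantize.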
+ """ + linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] + # TODO test for inplace add after refactoring of export_for_training + inplace_add_list = [False] + # TODO test for inplace relu after refactoring of export_for_training + inplace_relu_list = [False] + example_inputs = (torch.randn(2, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list) + with override_quantized_engine("x86"), torch.no_grad(): + for linear_pos, inplace_add, inplace_relu in cases: + m = TestHelperModules.LinearAddReLUModule( + inplace_add=inplace_add, + linear_pos=linear_pos, + inplace_relu=inplace_relu, + ).eval() + if linear_pos != NodePosType.both: + node_occurrence = { + # Only one q-dq node for input of the linear + # No q-dq node for extra input node of add + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + else: + # convert_pt2e disables duplicate dequant for dynamic quant + num_dequant = 1 if is_dynamic else 2 + node_occurrence = { + # One quantize_per_tensor for both linear nodes (shared) + # Two dequantize_per_tensor for two linear nodes + # No q-dq for extra input node of add + quantize_per_tensor_op: 1, + dequantize_per_tensor_op: num_dequant, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + quantize_per_tensor_op, + dequantize_per_tensor_op, + torch.ops.aten.linear.default, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), + ] + fq_m = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + )[-1] + # linear, add, relu are fused + # The other linear is quantized alone if present + aten = torch.ops.aten + add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor + relu_op = aten.relu_.default if inplace_relu else aten.relu.default + expected_annotation_stat = { + aten.linear.default: { + "annotated": 2 if linear_pos == NodePosType.both else 1, + "is_quant_out": 1 if linear_pos == NodePosType.both else 0, + }, + add_op: {"annotated": 1, "is_quant_out": 0}, + relu_op: {"annotated": 1, "is_quant_out": 1}, + } + self._check_annotation_stat(fq_m, expected_annotation_stat) + + @skipIfNoX86 + def test_linear_binary_unary(self): + self._test_linear_binary_unary_helper() + + @skipIfNoX86 + def test_linear_binary_unary_qat(self): + self._test_linear_binary_unary_helper(is_qat=True) + + @skipIfNoX86 + def test_linear_binary_unary_dynamic(self): + self._test_linear_binary_unary_helper(is_dynamic=True) + + @skipIfNoX86 + def test_linear_binary_unary_dynamic_qat(self): + self._test_linear_binary_unary_helper(is_qat=True, is_dynamic=True) + + @skipIfNoX86 + def test_linear_binary_unary_serials(self): + """ + Test pattern of 2 
following up linear add relu with X86InductorQuantizer. + """ + is_qat_list = [False, True] + is_dynamic_list = [False, True] + cases = itertools.product(is_qat_list, is_dynamic_list) + with override_quantized_engine("x86"), torch.no_grad(): + for is_qat, is_dynamic in cases: + m = TestHelperModules.SerialsLinearAddReLUModule().eval() + example_inputs = (torch.randn(2, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=is_qat, + is_dynamic=is_dynamic, + ) + ) + quantize_per_tensor_op = ( + torch.ops.quantized_decomposed.quantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + dequantize_per_tensor_op = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + if is_dynamic + else torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + # convert_pt2e disables duplicate dequant for dynamic quant + num_dequant = 3 if is_dynamic else 4 + node_occurrence = { + # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4 + # dequantize_per_tensor: 1 for each linear + # No q-dq for extra input node of add + quantize_per_tensor_op: 3, + dequantize_per_tensor_op: num_dequant, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 4, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + quantize_per_tensor_op, + dequantize_per_tensor_op, + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + torch.ops.aten.add.Tensor, + torch.ops.aten.relu.default, + ] + fq_m = self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + )[-1] + # Two linear nodes are quantized alone + # The other two are fused with add and relu + aten = torch.ops.aten + expected_annotation_stat = { + aten.linear.default: { + "annotated": 4, + "is_quant_out": 2, + }, + aten.add.Tensor: {"annotated": 2, "is_quant_out": 0}, + aten.relu.default: {"annotated": 2, "is_quant_out": 2}, + } + self._check_annotation_stat(fq_m, expected_annotation_stat) + + @skipIfNoX86 + def test_linear_dynamic_fp16(self): + """ + Test pattern of linear_dynamic_fp16. + """ + with override_quantized_engine("x86"), torch.no_grad(): + for use_bias in [True, False]: + m = TestHelperModules.SingleLinearModule(use_bias).eval() + example_inputs = (torch.randn(2, 4),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + quantizer.set_module_type_qconfig( + torch.nn.Linear, xiq.get_x86_inductor_linear_dynamic_fp16_config() + ) + node_occurrence = { + # 2 convert_element_type nodes are inserted for weight + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + torch.ops.quantized_decomposed.convert_element_type.no_fuse: 2, + } + node_list = [ + torch.ops.quantized_decomposed.convert_element_type.no_fuse, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfTorchDynamo("very slow") + @skipIfNoX86 + def test_qat_conv2d(self): + """ + Test QAT pattern of conv2d_bn with X86InductorQuantizer. 
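+        BatchNorm is expected to be folded into the conv, so no
+        _native_batch_norm_legit nodes remain after convert.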
+ """ + with override_quantized_engine("x86"): + m = TestHelperModules.SingleConv2dModule(with_bn=True) + example_inputs = (torch.randn(2, 3, 16, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config(is_qat=True) + ) + node_occurrence = { + # one for input and weight of the conv, one for output for the conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + # BN should be folded into Conv + torch.ops.aten._native_batch_norm_legit.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + + @skipIfTorchDynamo("very slow") + @skipIfNoX86 + def test_qat_conv2d_unary(self): + """ + Test QAT pattern of conv2d_bn with unary post ops (such as relu, sigmoid) with X86InductorQuantizer. + Currently, only relu as unary post op is supported. + """ + unary_map = { + "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default], + "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default], + "hardtanh": [ + torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), + torch.ops.aten.hardtanh.default, + ], + "hardtanh_inplace": [ + torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), + torch.ops.aten.hardtanh_.default, + ], + "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default], + "relu6_inplace": [ + torch.nn.ReLU6(inplace=True), + torch.ops.aten.hardtanh_.default, + ], + "hardswish": [ + torch.nn.Hardswish(inplace=False), + torch.ops.aten.hardswish.default, + ], + "hardswish_inplace": [ + torch.nn.Hardswish(inplace=True), + torch.ops.aten.hardswish_.default, + ], + "swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default], + "swish_inplace": [ + torch.nn.SiLU(inplace=True), + torch.ops.aten.silu_.default, + ], + } + + with override_quantized_engine("x86"): + for unary_op in unary_map.keys(): + m = TestHelperModules.Conv2dUnaryModule( + unary_map[unary_op][0], with_bn=True + ) + example_inputs = (torch.randn(2, 3, 16, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config(is_qat=True) + ) + node_occurrence = { + # one for input and weight of the conv, one for output for the relu + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + # BN should be folded into Conv + torch.ops.aten._native_batch_norm_legit.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + unary_map[unary_op][1], + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] 
+ self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + + @skipIfTorchDynamo("very slow") + @skipIfNoX86 + def test_qat_conv2d_binary(self): + """ + Test qat pattern of conv2d_bn with binary post ops (such as add) with X86InductorQuantizer. + Currently, only add as binary post op is supported. + """ + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config(is_qat=True) + ) + with override_quantized_engine("x86"): + for inplace_add in [True, False]: + m = TestHelperModules.Conv2dAddModule( + inplace_add=inplace_add, with_bn=True + ) + node_occurrence = { + # one for input and weight of the conv + # one for output for the add + # one for extra input node of add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + # BN should be folded into Conv + torch.ops.aten._native_batch_norm_legit.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + + @skipIfTorchDynamo("very slow") + @skipIfNoX86 + def test_qat_conv2d_binary2(self): + """ + Test qat Pattern: + tmp = bn1(conv2d_1(x)) + tmp2 = bn2(conv2d_2(tmp)) + return tmp + tmp2 + Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1 + """ + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config(is_qat=True) + ) + inplace_add_list = [True, False] + with override_quantized_engine("x86"), torch.no_grad(): + for inplace_add in inplace_add_list: + m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + # BN should be folded into Conv + torch.ops.aten._native_batch_norm_legit.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ( + torch.ops.aten.add_.Tensor + if inplace_add + else torch.ops.aten.add.Tensor + ), + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + + @skipIfTorchDynamo("very slow") + @skipIfNoX86 + def test_qat_conv2d_binary_unary(self): + """ + Test QAT pattern of conv2d_bn with binary + unary post ops (such as add + relu) with X86InductorQuantizer. + Currently, only add as binary post op and relu as unary post op are supported. 
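+        BatchNorm should be folded into the conv, and the extra input of the add
+        carries its own quantize/dequantize pair.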
+ """ + example_inputs = (torch.randn(2, 3, 6, 6),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config(is_qat=True) + ) + with override_quantized_engine("x86"): + m = TestHelperModules.Conv2dAddReLUModule(with_bn=True) + node_occurrence = { + # one for input for conv + # one for output for the relu + # one for extra input node of add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + # BN should be folded into Conv + torch.ops.aten._native_batch_norm_legit.default: 0, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + + @skipIfNoX86 + def test_dynamic_quant_linear(self): + """ + Test pattern of dynamic quantization of linear with X86InductorQuantizer. + """ + with override_quantized_engine("x86"), torch.no_grad(): + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config(is_dynamic=True) + ) + node_occurrence = { + torch.ops.quantized_decomposed.choose_qparams.tensor: 1, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + node_list = [ + torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_qat_dynamic_quant_linear(self): + """ + Test pattern of qat dynamic quantization of linear with X86InductorQuantizer. 
+ """ + with override_quantized_engine("x86"), torch.no_grad(): + m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() + example_inputs = (torch.randn(1, 4, 64),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config( + is_qat=True, is_dynamic=True + ) + ) + node_occurrence = { + torch.ops.quantized_decomposed.choose_qparams.tensor: 1, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + node_list = [ + torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + is_qat=True, + ) + + @skipIfNoX86 + def test_set_module_name_qconfig(self): + """Test case for quantizing a specific submodule by configuring `set_module_name_qconfig`. + + Expect that all linear layers within the submodule `sub` are quantized. + """ + + class Sub(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = torch.nn.Linear(5, 10) + self.relu1 = torch.nn.ReLU(inplace=False) + self.linear2 = torch.nn.Linear(10, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.relu1(x) + x = self.linear2(x) + return x + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + # Set global to `None` and then default config for a specific submodule. + quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", xiq.get_default_x86_inductor_quantization_config() + ) + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input of two linear layers from `sub` + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # dequantize the weight of two linear layers from `sub` + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + # first linear is not quantized + torch.ops.aten.linear.default, + # two Q/DQ pairs for two linear layers from `sub` + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_set_module_name_qconfig_with_underscores(self) -> None: + """Test that if a module name has an underscore, we can still quantize it.""" + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + # This module name has underscores, which can be part of a mangled name. 
+                self.foo_bar = torch.nn.Linear(2, 2)
+                self.baz = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.baz(self.foo_bar(x))
+
+        # Set global to no quantization and then default config for a specific submodule whose name includes an underscore.
+        quantizer = X86InductorQuantizer()
+        quantizer.set_module_name_qconfig(
+            "foo_bar", xiq.get_default_x86_inductor_quantization_config()
+        )
+        example_inputs = (torch.randn(2, 2),)
+        m = M().eval()
+        m = export_for_training(m, example_inputs).module()
+        m = prepare_pt2e(m, quantizer)
+        # Use a linear count instead of names because the names might change, but
+        # the order should be the same.
+        count = 0
+        for n in m.graph.nodes:
+            if n.op == "call_function" and n.target == torch.ops.aten.linear.default:
+                # Get the weight observer to see the per-channel vs per-tensor.
+                weight_observer_node = n.args[1]
+                if count == 0:
+                    # For foo_bar, expect a per-channel weight observer.
+                    self.assertEqual(
+                        weight_observer_node.op,
+                        "call_module",
+                        f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead of call_module",
+                    )
+                    observer_instance = getattr(m, weight_observer_node.target)
+                    self.assertEqual(
+                        observer_instance.qscheme, torch.per_channel_symmetric
+                    )
+                else:
+                    # For baz it should have no observer at all.
+                    self.assertNotEqual(
+                        weight_observer_node.op,
+                        "call_module",
+                        f"The op of linear({count})'s weight_observer_node should not be call_module, got {weight_observer_node.op}",
+                    )
+                count += 1
+
+    @skipIfNoX86
+    def test_set_module_name_and_module_type_case1(self):
+        """Test setting `module_name_qconfig` and `module_type_qconfig` at the same time.
+
+        Expect that only the last linear layer (`sub`) is quantized.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear1 = torch.nn.Linear(5, 10)
+                self.linear2 = torch.nn.Linear(10, 5)
+                self.sub = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.linear2(x)
+                x = self.sub(x)
+                return x
+
+        m = M().eval()
+        example_inputs = (torch.randn(3, 5),)
+        # Set `sub` with default config and then `None` for all `Linear`.
+        # The config set by `set_module_name_qconfig` has higher priority than `set_module_type_qconfig`.
+        quantizer = X86InductorQuantizer()
+        quantizer.set_module_name_qconfig(
+            "sub", xiq.get_default_x86_inductor_quantization_config()
+        ).set_module_type_qconfig(torch.nn.Linear, None)
+
+        node_occurrence = {
+            torch.ops.aten.linear.default: 3,
+            # quantize and dequantize the input of the last linear
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
+            # dequantize the weight of the last linear
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            # first and second linear are not quantized
+            torch.ops.aten.linear.default,
+            torch.ops.aten.linear.default,
+            # last linear is quantized
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.linear.default,
+        ]
+        self._test_quantizer(
+            m,
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+
+    @skipIfNoX86
+    def test_set_module_name_and_module_type_case2(self):
+        """Test setting `module_name_qconfig` and `module_type_qconfig` at the same time.
+
+        Expect that all linear layers except the last one are quantized.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear1 = torch.nn.Linear(5, 10)
+                self.linear2 = torch.nn.Linear(10, 5)
+                self.sub = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.linear2(x)
+                x = self.sub(x)
+                return x
+
+        m = M().eval()
+        example_inputs = (torch.randn(3, 5),)
+        # Set `sub` with `None` and then the default config for all `Linear` modules.
+        quantizer = X86InductorQuantizer()
+        quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig(
+            torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config()
+        )
+
+        node_occurrence = {
+            torch.ops.aten.linear.default: 3,
+            # quantize and dequantize the input and output of the first and second linear
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            # dequantize the weight of the first and second linear
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+        }
+        node_list = [
+            # Q/DQ for first linear
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.linear.default,
+            # Q/DQ for second linear
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.linear.default,
+            # last linear is not quantized
+            torch.ops.aten.linear.default,
+        ]
+        self._test_quantizer(
+            m,
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+
+    @skipIfNoX86
+    def test_set_module_name_qconfig_for_dynamic_quant(self):
+        """Test quantizing a specific submodule with dynamic quantization."""
+
+        with override_quantized_engine("x86"), torch.no_grad():
+            for is_qat in [False, True]:
+                m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval()
+                example_inputs = (torch.randn(1, 4, 64),)
+                # only quantize `q_proj` and `v_proj`
+                dynamic_config = xiq.get_default_x86_inductor_quantization_config(
+                    is_dynamic=True, is_qat=is_qat
+                )
+                quantizer = (
+                    X86InductorQuantizer()
+                    .set_module_name_qconfig("q_proj", dynamic_config)
+                    .set_module_name_qconfig("v_proj", dynamic_config)
+                )
+                node_occurrence = {
+                    # quantize and dequantize the input
+                    torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
+                    # dequantize the weight of q_proj and v_proj
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
+                }
+                node_list = [
+                    # quantize and dequantize the input
+                    torch.ops.quantized_decomposed.choose_qparams.tensor,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+                    # q_proj
+                    torch.ops.aten.linear.default,
+                    # k_proj
+                    torch.ops.aten.linear.default,
+                    # v_proj
+                    torch.ops.aten.linear.default,
+                ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                    is_qat=is_qat,
+                )
+
+    @skipIfNoX86
+    def test_set_module_name_with_mixed_configs(self):
+        """Test case for setting module names with mixed static/dynamic or QAT/non-QAT configurations.
+
+        The config for 'v_proj' will always be ignored and a warning will be raised.
+        """
+        with override_quantized_engine("x86"), torch.no_grad():
+            with self.assertWarns(UserWarning) as context:
+                for q_is_dynamic, v_is_dynamic, q_is_qat, v_is_qat in itertools.product(
+                    [False, True], repeat=4
+                ):
+                    if q_is_dynamic == v_is_dynamic and q_is_qat == v_is_qat:
+                        continue
+                    m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval()
+                    example_inputs = (torch.randn(1, 4, 64),)
+                    quantizer = (
+                        X86InductorQuantizer()
+                        .set_module_name_qconfig(
+                            "q_proj",
+                            xiq.get_default_x86_inductor_quantization_config(
+                                is_qat=q_is_qat, is_dynamic=q_is_dynamic
+                            ),
+                        )
+                        .set_module_name_qconfig(
+                            "v_proj",
+                            xiq.get_default_x86_inductor_quantization_config(
+                                is_qat=v_is_qat, is_dynamic=v_is_dynamic
+                            ),
+                        )
+                    )
+                    quant_op = (
+                        torch.ops.quantized_decomposed.quantize_per_tensor.tensor
+                        if q_is_dynamic
+                        else torch.ops.quantized_decomposed.quantize_per_tensor.default
+                    )
+                    dequant_op = (
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
+                        if q_is_dynamic
+                        else torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                    )
+                    node_occurrence = {
+                        # quantize and dequantize the input
+                        quant_op: 1,
+                        dequant_op: 1,
+                        # only `q_proj` was quantized, dequantize its weight
+                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+                    }
+                    node_list = [
+                        # quantize and dequantize the input
+                        quant_op,
+                        dequant_op,
+                        # q_proj
+                        torch.ops.aten.linear.default,
+                        # k_proj/v_proj
+                        torch.ops.aten.linear.default,
+                        torch.ops.aten.linear.default,
+                    ]
+                    self._test_quantizer(
+                        m,
+                        example_inputs,
+                        quantizer,
+                        node_occurrence,
+                        node_list,
+                        is_qat=q_is_qat,
+                    )
+                    warning_msg = (
+                        "Mixed QAT and Non-QAT"
+                        if q_is_qat != v_is_qat
+                        else "Mixed dynamic and static"
+                    )
+                    self.assertTrue(
+                        any(
+                            warning_msg in msg
+                            for msg in [str(w.message) for w in context.warnings]
+                        )
+                    )
+
+    @skipIfNoX86
+    def test_set_module_name_and_module_type_with_mixed_configs(self):
+        """Test setting `module_name_qconfig` and `module_type_qconfig` at the same time with mixed configs.
+
+        Expect that only the last linear (`sub`) is quantized using static quantization.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear1 = torch.nn.Linear(5, 10)
+                self.linear2 = torch.nn.Linear(10, 5)
+                self.sub = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.linear2(x)
+                x = self.sub(x)
+                return x
+
+        m = M().eval()
+        example_inputs = (torch.randn(3, 5),)
+        # Set `sub` with a static config and then a dynamic config for all `Linear` modules (ignored).
+ quantizer = X86InductorQuantizer() + quantizer.set_module_name_qconfig( + "sub", xiq.get_default_x86_inductor_quantization_config(is_dynamic=False) + ).set_module_type_qconfig( + torch.nn.Linear, + xiq.get_default_x86_inductor_quantization_config(is_dynamic=True), + ) + + node_occurrence = { + torch.ops.aten.linear.default: 3, + # quantize and dequantize the input of the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # dequantize the weight of the last linear + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + # first and second linear are not quantized + torch.ops.aten.linear.default, + torch.ops.aten.linear.default, + # Q/DQ pairs for the last linear + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_filter_conv2d_recipe(self): + """ + Test removing conv2d from default recipe of X86InductorQuantizer. + """ + with override_quantized_engine("x86"), torch.no_grad(): + m = TestHelperModules.Conv2dUnaryModule(torch.nn.ReLU(inplace=False)).eval() + example_inputs = (torch.randn(2, 3, 16, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + quantizer.set_module_type_qconfig(torch.nn.Conv2d, None) + node_occurrence = { + # one for input and weight of the conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + torch.ops.aten.conv2d.default, + torch.ops.aten.relu.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_filter_linear_recipe(self): + """ + Test removing linear from default recipe of X86InductorQuantizer. + """ + with override_quantized_engine("x86"), torch.no_grad(): + m = TestHelperModules.LinearUnaryModule( + use_bias=True, + postop=nn.ReLU, + ).eval() + example_inputs = (torch.randn(2, 4),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + quantizer.set_function_type_qconfig(torch.nn.functional.linear, None) + node_occurrence = { + # one for input and weight of the conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, + } + node_list = [ + torch.ops.aten.linear.default, + torch.ops.aten.relu.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_filter_maxpool2d_recipe(self): + """ + Test removing maxpool2d from default recipe of X86InductorQuantizer. 
+ """ + with override_quantized_engine("x86"), torch.no_grad(): + m = TestHelperModules.Conv2dUnaryModule(torch.nn.ReLU(inplace=False)).eval() + example_inputs = (torch.randn(2, 3, 16, 16),) + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + quantizer.set_function_type_qconfig(torch.nn.functional.max_pool2d, None) + node_occurrence = { + # one for input and weight of the conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.relu.default, + torch.ops.aten.max_pool2d.default, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + @skipIfNoX86 + def test_attention_block(self): + """ + Test pattern of Attention like Block with X86InductorQuantizer. + """ + for annotate_matmul in [False, True]: + with override_quantized_engine("x86"), torch.no_grad(): + m = TestHelperModules.SelfAttnLikeModule( + input_dim=64 * 16, + transpose_for_score=True, + num_attention_heads=16, + attention_head_size=64, + ).eval() + example_inputs = (torch.randn(2, 384, 1024),) + + m(*example_inputs) + + quantizer = X86InductorQuantizer().set_global( + xiq.get_default_x86_inductor_quantization_config() + ) + + if annotate_matmul: + quantizer.set_function_type_qconfig( + torch.matmul, quantizer.get_global_quantization_config() + ) + + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: ( + 5 if annotate_matmul else 1 + ), + torch.ops.quantized_decomposed.dequantize_per_tensor.default: ( + 7 if annotate_matmul else 3 + ), + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + if annotate_matmul: + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.aten.view.default, + torch.ops.aten.permute.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.matmul.default, + torch.ops.aten.div.Tensor, + torch.ops.aten.softmax.int, + ] + else: + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.aten.view.default, + torch.ops.aten.permute.default, + torch.ops.aten.matmul.default, + torch.ops.aten.div.Tensor, + torch.ops.aten.softmax.int, + ] + self._test_quantizer( + m, + example_inputs, + quantizer, + node_occurrence, + node_list, + ) diff --git a/test/quantization/pt2e_flow/test_xnnpack_quantizer.py b/test/quantization/pt2e_flow/test_xnnpack_quantizer.py new file mode 100644 index 0000000000..33b35ffe37 --- /dev/null +++ b/test/quantization/pt2e_flow/test_xnnpack_quantizer.py @@ -0,0 +1,1092 @@ +# Owner(s): ["oncall: mobile"] +import copy +import operator + +import torch +import torch._dynamo as torchdynamo +from torch.ao.ns.fx.utils import 
compute_sqnr +from torch.ao.quantization import ( + QConfig, + QConfigMapping, + default_dynamic_qconfig, + observer, +) +from torch.ao.quantization.backend_config import get_qnnpack_backend_config +from torch.ao.quantization.qconfig import ( + default_per_channel_symmetric_qnnpack_qconfig, + default_symmetric_qnnpack_qconfig, + per_channel_weight_observer_range_neg_127_to_127, + weight_observer_range_neg_127_to_127, +) +from torch.ao.quantization.quantize_fx import ( + _convert_to_reference_decomposed_fx, + convert_to_reference_fx, + prepare_fx, +) +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, +) +from torch.testing._internal.common_quantization import ( + TestHelperModules, + skip_if_no_torchvision, + skipIfNoQNNPACK, +) +from torch.testing._internal.common_quantized import override_quantized_engine + +from torchao.quantization.pt2e_flow.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, +) +from torchao.testing.pt2e.utils import PT2EQuantizationTestCase + + +@skipIfNoQNNPACK +class TestXNNPACKQuantizer(PT2EQuantizationTestCase): + def test_conv1d(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv1d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.ConvWithBNRelu(dim=1, relu=False, bn=False), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_conv2d(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.ConvWithBNRelu(relu=False, bn=False), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_conv1d_with_conv2d(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + 
torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv1d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + m = TestHelperModules.Conv2dThenConv1d() + self._test_quantizer( + m, + m.example_inputs(), + quantizer, + node_occurrence, + node_list, + ) + + def test_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_3d = (torch.randn(9, 10, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_linear_relu(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.LinearReluModel().eval() + + # Test with 2d inputs + example_inputs_2d = (torch.randn(1, 5),) + example_inputs_3d = (torch.randn(1, 2, 5),) + example_inputs_4d = (torch.randn(1, 2, 3, 5),) + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + # There should not be extra quantize_per_tensor or dequantize_per_tensors for relu + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], # node_list + False, # executorch_backend_config() does not fuse linear-relu + qconfig_mapping, + ) + + def test_conv_linear_no_permute(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel 
+ torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + self._test_quantizer( + TestHelperModules.Conv2dWithTwoLinear(), + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_conv_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, + } + qconfig = torch.ao.quantization.default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + self._test_quantizer( + TestHelperModules.Conv2dWithTwoLinearPermute(), + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_linear_with_dynamic_shape(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + # Test with 2d inputs + example_inputs_3d = (torch.randn(9, 10, 8),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + self._test_quantizer( + m_eager, + example_inputs_3d, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + export_with_dynamic_shape=True, + ) + + def test_obs_sharing_ops(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m = TestHelperModules.Conv2dWithObsSharingOps().eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + 
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.adaptive_avg_pool2d.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.hardtanh.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.mean.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        ]
+        self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)
+
+    def test_set_module_name(self):
+        class Sub(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+                self.sub = Sub()
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = self.sub(x)
+                return x
+
+        m = M().eval()
+        example_inputs = (torch.randn(3, 5),)
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_module_name("sub", quantization_config)
+        node_occurrence = {
+            torch.ops.aten.linear.default: 2,
+            # input and output for the second linear
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+        }
+        node_list = [
+            # first linear is not quantized
+            torch.ops.aten.linear.default,
+            # second linear is quantized
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.linear.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        ]
+        self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)
+
+    def test_set_module_name_with_underscores(self) -> None:
+        """Test that if a module name has an underscore, we can still quantize it."""
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                # This module name has underscores, which can be part of a mangled
+                # name.
+                self.foo_bar = torch.nn.Linear(2, 2)
+                self.baz = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.baz(self.foo_bar(x))
+
+        quantizer = XNNPACKQuantizer()
+        # Set global to no quantization and then per-channel for a specific submodule.
+        quantizer.set_module_name(
+            "foo_bar", get_symmetric_quantization_config(is_per_channel=True)
+        )
+        example_inputs = (torch.randn(2, 2),)
+        m = M().eval()
+        m = export_for_training(m, example_inputs).module()
+        m = prepare_pt2e(m, quantizer)
+        # Use a linear count instead of names because the names might change, but
+        # the order should be the same.
+        count = 0
+        for n in m.graph.nodes:
+            if n.op == "call_function" and n.target == torch.ops.aten.linear.default:
+                # Get the weight observer to see the per-channel vs per-tensor.
+                weight_observer_node = n.args[1]
+                if count == 0:
+                    # The weight observer for foo_bar should be per-channel,
+                    # matching the per-channel config set above.
+                    self.assertEqual(weight_observer_node.op, "call_module")
+                    observer_instance = getattr(m, weight_observer_node.target)
+                    self.assertEqual(
+                        observer_instance.qscheme, torch.per_channel_symmetric
+                    )
+                else:
+                    # For baz it should have no observer at all.
+ self.assertNotEqual(weight_observer_node.op, "call_module") + count += 1 + + def test_set_module_type(self): + class Sub(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Sub() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + return x + + m = M().eval() + example_inputs = (torch.randn(3, 5),) + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_module_type(Sub, quantization_config) + node_occurrence = { + torch.ops.aten.linear.default: 2, + # input and output for the second linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + } + node_list = [ + # first linear is not quantized + torch.ops.aten.linear.default, + # second linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + + def test_set_module_type_case_2(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.conv2 = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.conv3 = torch.nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + self.relu = torch.nn.ReLU() + self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.fc = torch.nn.Linear(3, 16) + + def forward(self, x): + x1 = self.conv(x) + x2 = self.relu(self.conv2(x1) + self.conv3(x1)) + x3 = self.avgpool(x2) + x4 = torch.flatten(x3, 1) + x5 = self.fc(x4) + return x5 + + m = M().eval() + example_inputs = (torch.randn(1, 3, 16, 16),) + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + # We only want to annotate Linear type + quantizer.set_module_type(torch.nn.Linear, quantization_config) + node_occurrence = { + torch.ops.aten.conv2d.default: 3, + torch.ops.aten.linear.default: 1, + # input and output for the linear + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + } + node_list = [ + # only the linear is quantized + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list) + + def test_propagate_annotation(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m = TestHelperModules.Conv2dPropAnnotaton().eval() + example_inputs = (torch.randn(1, 3, 5, 5),) + + # program capture + m = export_for_training( + 
m, + example_inputs, + ).module() + + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + for n in m.graph.nodes: + if n.target in [ + torch.ops.aten.view.default, + torch.ops.aten.hardtanh.default, + ]: + input_act = getattr(m, n.args[0].target) + output_act = getattr(m, next(iter(n.users)).target) + self.assertIs(input_act, output_act) + + m = convert_pt2e(m) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ): 5, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): 5, + # note: quantize op for weights are const propagated + ns.call_function( + torch.ops.quantized_decomposed.quantize_per_channel.default + ): 0, + ns.call_function( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ): 2, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_dynamic_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=True, is_dynamic=True + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127, + ) + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + for example_inputs in [example_inputs_2d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_dynamic_linear_int4_weight(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=True, + weight_qmin=0, + weight_qmax=15, + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127.with_args( + quant_min=0, quant_max=15 + ), + ) + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + 
example_inputs_2d = (torch.randn(9, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + for example_inputs in [example_inputs_2d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + ) + + def test_qat_dynamic_linear(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=True, + is_qat=True, + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.TwoLinearModule().eval() + + node_occurrence = { + torch.ops.quantized_decomposed.choose_qparams.tensor: 2, + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, + } + act_affine_quant_obs = torch.ao.quantization.default_dynamic_fake_quant + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=per_channel_weight_observer_range_neg_127_to_127, + ) + qconfig_mapping = QConfigMapping().set_global(qconfig) + # Test with 2d inputs + example_inputs_2d = (torch.randn(9, 8),) + example_inputs_4d = (torch.randn(9, 10, 11, 8),) + for example_inputs in [example_inputs_2d, example_inputs_4d]: + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + is_qat=True, + ) + + def test_dynamic_linear_with_conv(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=True + ) + quantizer.set_global(quantization_config) + m_eager = TestHelperModules.ConvLinearWPermute().eval() + + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, + } + + training_ir_node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + # In training IR, the decomposition is different. + # `torch.ops.quantized_decomposed.quantize_per_tensor.default` nodes becomes + # `torch.ops.quantized_decomposed.quantize_per_tensor.tensor` nodes. 
+ torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + act_affine_quant_obs = observer.PlaceholderObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_affine, + quant_min=-128, + quant_max=127, + eps=2**-12, + is_dynamic=True, + ) + qconfig = QConfig( + activation=act_affine_quant_obs, + weight=weight_observer_range_neg_127_to_127, + ) + # Test with 2d inputs + example_inputs = (torch.randn(2, 3, 4, 4),) + qconfig_mapping = QConfigMapping().set_global(qconfig) + self._test_quantizer( + m_eager, + example_inputs, + quantizer, + node_occurrence, + [], + True, + qconfig_mapping, + training_ir_node_occurrence=training_ir_node_occurrence, + ) + + def test_gru(self): + """this is a test for annotating fp32 GRU so that it produces + q -> dq -> fp32_gru -> q -> dq, this is currently enough for our use cases, + but we may change the annotation to be more precise in the future + """ + + class RNNDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + super().__init__() + self.qconfig = default_dynamic_qconfig + if mod_type == "GRU": + self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) + if mod_type == "LSTM": + self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, input_tensor, hidden_tensor): + input_tensor = 1 * input_tensor + hidden_tensor = 1 * hidden_tensor + output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor) + return 1 * output_tensor, 1 * hidden_out + + with override_quantized_engine("qnnpack"): + model_fx = RNNDynamicModel("GRU") + niter = 10 + example_inputs = ( + # input_tensor + torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) + .unsqueeze(0) + .repeat(niter, 1, 1), + # hidden_tensor + # (D * num_layers, N, H_out) + torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1), + ) + model_graph = copy.deepcopy(model_fx) + + qconfig_mapping = QConfigMapping().set_object_type( + operator.mul, default_symmetric_qnnpack_qconfig + ) + model_fx = prepare_fx( + model_fx, + qconfig_mapping, + example_inputs, + backend_config=get_qnnpack_backend_config(), + ) + model_fx(*example_inputs) + model_fx = _convert_to_reference_decomposed_fx(model_fx) + + with torchdynamo.config.patch(allow_rnn=True): + model_graph = export_for_training( + model_graph, + example_inputs, + ).module() + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=False + ) + quantizer.set_global(quantization_config) + model_graph = prepare_pt2e(model_graph, quantizer) + model_graph(*example_inputs) + model_graph = convert_pt2e(model_graph) + self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs)) + + def test_linear_gru(self): + """this test is to make sure GRU annotation does not interfere with linear annotation""" + + class RNNDynamicModel(torch.nn.Module): + def __init__(self, mod_type): + super().__init__() + self.qconfig = default_dynamic_qconfig + self.linear = torch.nn.Linear(2, 2) + if mod_type == "GRU": + self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) + if mod_type == "LSTM": + self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) + + def forward(self, input_tensor, hidden_tensor): + input_tensor = self.linear(input_tensor) + input_tensor = 1 * input_tensor + hidden_tensor = 1 * 
hidden_tensor + output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor) + return 1 * output_tensor, 1 * hidden_out + + with override_quantized_engine("qnnpack"): + model_fx = RNNDynamicModel("GRU") + niter = 10 + example_inputs = ( + # input_tensor + torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float) + .unsqueeze(0) + .repeat(niter, 1, 1), + # hidden_tensor + # (D * num_layers, N, H_out) + torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1), + ) + model_graph = copy.deepcopy(model_fx) + + qconfig_mapping = ( + QConfigMapping() + .set_object_type(operator.mul, default_symmetric_qnnpack_qconfig) + .set_object_type(torch.nn.Linear, default_symmetric_qnnpack_qconfig) + ) + model_fx = prepare_fx( + model_fx, + qconfig_mapping, + example_inputs, + backend_config=get_qnnpack_backend_config(), + ) + model_fx(*example_inputs) + model_fx = _convert_to_reference_decomposed_fx(model_fx) + + with torchdynamo.config.patch(allow_rnn=True): + model_graph = export_for_training( + model_graph, + example_inputs, + ).module() + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config( + is_per_channel=False, is_dynamic=False + ) + quantizer.set_global(quantization_config) + model_graph = prepare_pt2e(model_graph, quantizer) + model_graph(*example_inputs) + model_graph = convert_pt2e(model_graph) + self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs)) + + def test_add_and_inplace_add(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + ) + node_occurrence = { + # two input and one output for first add, and output for second add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.AddInplaceAdd(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_mul_and_inplace_mul(self): + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = ( + torch.randn(1, 3, 5, 5), + torch.randn(1, 3, 5, 5), + ) + node_occurrence = { + # two input and one output for first add, and output for second add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.MulInplaceMul(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_add_mul_scalar(self): + 
quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + # two input and one output for first add, and output for second add + torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, + # TODO torch.ops.quantized_decomposed.dequantize_per_tensor.default: 9, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.add.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + # TODO torch.ops.aten.mul.Tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + self._test_quantizer( + TestHelperModules.AddMulScalar(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_mul_float32_max(self): + class M(torch.nn.Module): + def forward(self, x): + return x * 3.4028235e38 + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + # not quantized + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + node_list = [ + torch.ops.aten.mul.Tensor, + ] + self._test_quantizer( + M(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_add_mul_long(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.t = torch.tensor([100]) + + def forward(self, x): + x = x + self.t + x = x * self.t + return x + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + # not quantized + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } + node_list = [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mul.Tensor, + ] + self._test_quantizer( + M(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + def test_cat_same_node(self): + """Ensure that concatenating the same node does not cause any unexpected behavior""" + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.cat([x, x]) + return x + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + example_inputs = (torch.randn(1, 3, 5, 5),) + node_occurrence = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.cat.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + 
torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + self._test_quantizer( + M(), + example_inputs, + quantizer, + node_occurrence, + node_list, + ) + + +# TODO: express this using self._test_quantizer, add test for inception_v4 +class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase): + @skip_if_no_torchvision + @skipIfNoQNNPACK + def test_resnet18(self): + import torchvision + + with override_quantized_engine("qnnpack"): + example_inputs = (torch.randn(1, 3, 224, 224),) + m = torchvision.models.resnet18().eval() + m_copy = copy.deepcopy(m) + # program capture + m = export_for_training( + m, + example_inputs, + ).module() + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(quantization_config) + m = prepare_pt2e(m, quantizer) + # checking that we inserted observers correctly for maxpool operator (input and + # output share observer instance) + self.assertEqual( + id(m.activation_post_process_3), id(m.activation_post_process_2) + ) + after_prepare_result = m(*example_inputs) + m = convert_pt2e(m) + + after_quant_result = m(*example_inputs) + + # comparing with existing fx graph mode quantization reference flow + qconfig = default_per_channel_symmetric_qnnpack_qconfig + qconfig_mapping = QConfigMapping().set_global(qconfig) + backend_config = get_qnnpack_backend_config() + m_fx = prepare_fx( + m_copy, qconfig_mapping, example_inputs, backend_config=backend_config + ) + after_prepare_result_fx = m_fx(*example_inputs) + m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config) + + after_quant_result_fx = m_fx(*example_inputs) + + # the result matches exactly after prepare + # Note: this currently will always be true since we are inserting observers + # the check becomes useful when we add qat examples + # but we can still manully inspect the printed observers to make sure + # it matches + self.assertEqual(after_prepare_result, after_prepare_result_fx) + self.assertEqual( + compute_sqnr(after_prepare_result, after_prepare_result_fx), + torch.tensor(float("inf")), + ) + # there are slight differences after convert due to different implementations + # of quant/dequant + self.assertTrue( + torch.max(after_quant_result - after_quant_result_fx) < 1e-1 + ) + self.assertTrue( + compute_sqnr(after_quant_result, after_quant_result_fx) > 35 + ) diff --git a/torchao/quantization/pt2e_flow/__init__.py b/torchao/quantization/pt2e_flow/__init__.py new file mode 100644 index 0000000000..ffb2972462 --- /dev/null +++ b/torchao/quantization/pt2e_flow/__init__.py @@ -0,0 +1,166 @@ +# mypy: allow-untyped-defs + +from typing import Callable, Optional, Union + +import torch +from torch import Tensor + +from .fake_quantize import ( + FakeQuantize, + FakeQuantizeBase, + FixedQParamsFakeQuantize, + enable_fake_quant, + enable_observer, +) +from .observer import ( + FixedQParamsObserver, + FusedMovingAvgObsFakeQuantize, + Granularity, + HistogramObserver, + MappingType, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + NoopObserver, + ObserverBase, + PerAxis, + PerBlock, + PerChannelMinMaxObserver, + PerGroup, + PerRow, + PerTensor, + PerToken, + PlaceholderObserver, + RecordingObserver, + ReuseInputObserver, + TorchAODType, + UniformQuantizationObserverBase, + ZeroPointDomain, + get_block_size, +) +from .pt2e._affine_quantization import AffineQuantizedObserverBase +from .pt2e._numeric_debugger import ( # noqa: F401 + CUSTOM_KEY, + NUMERIC_DEBUG_HANDLE_KEY, + 
compare_results, + extract_results_from_loggers, + generate_numeric_debug_handle, + prepare_for_propagation_comparison, +) +from .pt2e.export_utils import ( + _allow_exported_model_train_eval as allow_exported_model_train_eval, +) +from .pt2e.export_utils import ( + _move_exported_model_to_eval as move_exported_model_to_eval, +) +from .pt2e.export_utils import ( + _move_exported_model_to_train as move_exported_model_to_train, +) +from .qconfig import * # noqa: F403 + +# ensure __module__ is set correctly for public APIs +ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] +ObserverOrFakeQuantize.__module__ = "torchao.quantization.pt2e_flow" + +for _f in [ + compare_results, + extract_results_from_loggers, + generate_numeric_debug_handle, + prepare_for_propagation_comparison, +]: + _f.__module__ = "torchao.quantization.pt2e_flow" + +__all__ = [ + "FakeQuantize", + "FakeQuantizeBase", + "FixedQParamsFakeQuantize", + "FixedQParamsObserver", + "FusedMovingAvgObsFakeQuantize", + "HistogramObserver", + "MinMaxObserver", + "MovingAverageMinMaxObserver", + "MovingAveragePerChannelMinMaxObserver", + "NoopObserver", + "ObserverBase", + "ObserverOrFakeQuantize", + "PerChannelMinMaxObserver", + "PlaceholderObserver", + "RecordingObserver", + "ReuseInputObserver", + "UniformQuantizationObserverBase", + "enable_fake_quant", + "enable_observer", + "move_exported_model_to_eval", + "move_exported_model_to_train", + "allow_exported_model_train_eval", + # pt2e numeric debugger + "generate_numeric_debug_handle", + "CUSTOM_KEY", + "NUMERIC_DEBUG_HANDLE_KEY", + "prepare_for_propagation_comparison", + "extract_results_from_loggers", + "compare_results", + # from torchao, should be merged with torchao + # in the future + "AffineQuantizedObserverBase", + "Granularity", + "MappingType", + "PerAxis", + "PerBlock", + "PerGroup", + "PerRow", + "PerTensor", + "PerToken", + "TorchAODType", + "ZeroPointDomain", + "get_block_size", +] + + +def default_eval_fn(model, calib_data): + r"""Define the default evaluation function. 
+ + Default evaluation function takes a torch.utils.data.Dataset or a list of + input Tensors and run the model on the dataset + """ + for data, _target in calib_data: + model(data) + + +class _DerivedObserverOrFakeQuantize(ObserverBase): + r"""This observer is used to describe an observer whose quantization parameters + are derived from other observers + """ + + def __init__( + self, + dtype: torch.dtype, + obs_or_fqs: list[ObserverOrFakeQuantize], + derive_qparams_fn: Callable[ + [list[ObserverOrFakeQuantize]], tuple[Tensor, Tensor] + ], + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + qscheme: Optional[torch.qscheme] = None, + ch_axis: Optional[int] = None, + ): + super().__init__(dtype) + self.obs_or_fqs = obs_or_fqs + self.derive_qparams_fn = derive_qparams_fn + self.quant_min = quant_min + self.quant_max = quant_max + self.qscheme = qscheme + self.ch_axis = ch_axis + + from .utils import is_per_channel + + if is_per_channel(self.qscheme): + assert ( + self.ch_axis is not None + ), "Must provide a valid ch_axis if qscheme is per channel" + + def forward(self, x: Tensor) -> Tensor: + return x + + def calculate_qparams(self): # type:ignore[override] + return self.derive_qparams_fn(self.obs_or_fqs) diff --git a/torchao/quantization/pt2e_flow/fake_quantize.py b/torchao/quantization/pt2e_flow/fake_quantize.py new file mode 100644 index 0000000000..9803067757 --- /dev/null +++ b/torchao/quantization/pt2e_flow/fake_quantize.py @@ -0,0 +1,650 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +"""Implements modules used to perform fake quantization.""" + +import re +from abc import ABC, abstractmethod +from typing import Any + +import torch +from torch.nn import Module + +from torchao.quantization.pt2e_flow.observer import ( + FixedQParamsObserver, + HistogramObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + _with_args, + default_fixed_qparams_range_0to1_observer, + default_fixed_qparams_range_neg1to1_observer, +) + +__all__ = [ + "FakeQuantizeBase", + "FakeQuantize", + "FixedQParamsFakeQuantize", + "FusedMovingAvgObsFakeQuantize", + "disable_fake_quant", + "disable_observer", + "enable_fake_quant", + "enable_observer", + "default_fake_quant", + "default_weight_fake_quant", + "default_dynamic_fake_quant", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_symmetric_fixed_qparams_fake_quant", + "default_affine_fixed_qparams_fake_quant", + "default_per_channel_weight_fake_quant", + "default_embedding_fake_quant", + "default_embedding_fake_quant_4bit", + "default_histogram_fake_quant", + "default_fused_act_fake_quant", + "default_fused_wt_fake_quant", + "default_fused_per_channel_wt_fake_quant", + "fused_wt_fake_quant_range_neg_127_to_127", + "fused_per_channel_wt_fake_quant_range_neg_127_to_127", +] + + +def _is_per_channel(qscheme: "torch.qscheme") -> bool: + return qscheme in [ + torch.per_channel_symmetric, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ] + + +def _is_per_tensor(qscheme: "torch.qscheme") -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] + + +def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric] + + +def _is_float_qparams(qscheme: "torch.qscheme") -> bool: + return qscheme in [ + torch.per_channel_affine_float_qparams, + ] + + +class FakeQuantizeBase(ABC, Module): + r"""Base fake quantize module. 
+ + Base fake quantize module + Any fake quantize implementation should derive from this class. + + Concrete fake quantize module should follow the same API. In forward, they will update + the statistics of the observed Tensor and fake quantize the input. They should also provide a + `calculate_qparams` function that computes the quantization parameters given + the collected statistics. + + """ + + fake_quant_enabled: torch.Tensor + observer_enabled: torch.Tensor + + def __init__(self) -> None: + """Set fake_quant_enabled and observer_enabled.""" + super().__init__() + # fake_quant_enabled and observer_enabled are buffers to support their + # replication in DDP. Data type is uint8 because NCCL does not support + # bool tensors. + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8)) + self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.uint8)) + + @abstractmethod + def forward(self, x): + pass + + @abstractmethod + def calculate_qparams(self, **kwargs): + pass + + @torch.jit.export + def enable_fake_quant(self, enabled: bool = True) -> None: + self.fake_quant_enabled[0] = 1 if enabled else 0 + + @torch.jit.export + def disable_fake_quant(self): + self.enable_fake_quant(False) + + @torch.jit.export + def enable_observer(self, enabled: bool = True) -> None: + self.observer_enabled[0] = 1 if enabled else 0 + + @torch.jit.export + def disable_observer(self): + self.enable_observer(False) + + @classmethod + def with_args(cls, **kwargs): + fake_quant_constructor = _with_args(cls, **kwargs) + # need to assign the correct module to fake_quantize + # constructors to satisfy public v private requirements + fake_quant_constructor.__module__ = ( + "torchao.quantization.pt2e_flow.fake_quantize" + ) + return fake_quant_constructor + + +class FakeQuantize(FakeQuantizeBase): + r"""Simulate the quantize and dequantize operations in training time. + + The output of this module is given by:: + + x_out = ( + clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point + ) * scale + + * :attr:`is_dynamic` indicates whether the fake quantie is a placeholder for dynamic quantization + operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq) + + * :attr:`scale` defines the scale factor used for quantization. + + * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to + + * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that + statistics can still be updated. + + * :attr:`observer_enabled` controls statistics collection on tensors + + * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization, + allowable values are torch.qint8 and torch.quint8. + + Args: + + observer (module): Module for observing statistics on input tensors and calculating scale + and zero-point. + observer_kwargs (optional): Arguments for the observer module + + Attributes: + activation_post_process (Module): User provided module that collects statistics on the input tensor and + provides a method to calculate scale and zero-point. 
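+
+ Example (illustrative sketch, not part of the original docs; uses the default
+ ``MovingAverageMinMaxObserver`` with 8-bit ``torch.quint8`` settings)::
+
+ >>> # xdoctest: +SKIP
+ >>> import torch
+ >>> from torchao.quantization.pt2e_flow import FakeQuantize
+ >>> fq = FakeQuantize(quant_min=0, quant_max=255, dtype=torch.quint8)
+ >>> y = fq(torch.randn(2, 3)) # update observer statistics, then fake-quantize
+ >>> scale, zero_point = fq.calculate_qparams()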
+ + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__( + self, + observer=MovingAverageMinMaxObserver, + quant_min=None, + quant_max=None, + is_dynamic=False, + **observer_kwargs, + ): + super().__init__() + # Populate quant_min/quant_max to observer_kwargs if valid + if quant_min is not None and quant_max is not None: + assert ( + quant_min <= quant_max + ), "quant_min must be less than or equal to quant_max" + dtype = observer_kwargs.get("dtype", torch.quint8) + if hasattr(observer, "p"): + # In case observer is _PartialWrapper, dtype can be stored in + # observer.p.keywords["dtype"] + dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get( + "dtype", dtype + ) + assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound" + assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound" + observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max}) + observer_kwargs["is_dynamic"] = is_dynamic + self.activation_post_process = observer(**observer_kwargs) + # TODO: keeping self.quant_min/max for BC; remove after a couple releases + # Users should use self.activation_post_process.quant_min + self.quant_min = self.activation_post_process.quant_min + self.quant_max = self.activation_post_process.quant_max + self.is_dynamic = self.activation_post_process.is_dynamic + if _is_float_qparams(self.activation_post_process.qscheme): + zero_point_dtype = torch.float + else: + zero_point_dtype = torch.int + self.register_buffer("scale", torch.tensor([1.0], dtype=torch.float)) + self.register_buffer("zero_point", torch.tensor([0], dtype=zero_point_dtype)) + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = ( + self.activation_post_process.ch_axis + if hasattr(self.activation_post_process, "ch_axis") + else -1 + ) + assert _is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme), ( + "Only per channel and per tensor quantization are supported in fake quantize" + + " got qscheme: " + + str(self.qscheme) + ) + self.is_per_channel = _is_per_channel(self.qscheme) + + @torch.jit.export + def calculate_qparams(self): + return self.activation_post_process.calculate_qparams() + + def forward(self, X): + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.calculate_qparams() + _scale, _zero_point = ( + _scale.to(self.scale.device), + _zero_point.to(self.zero_point.device), + ) + if self.scale.shape != _scale.shape: + self.scale.resize_(_scale.shape) + self.zero_point.resize_(_zero_point.shape) + self.scale.copy_(_scale) + self.zero_point.copy_(_zero_point) + + if self.fake_quant_enabled[0] == 1: + if self.is_per_channel: + X = torch.fake_quantize_per_channel_affine( + X, + self.scale, + self.zero_point, + self.ch_axis, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + ) + else: + X = torch.fake_quantize_per_tensor_affine( + X, + self.scale, + self.zero_point, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + ) + return X + + @torch.jit.export + def extra_repr(self): + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, " + f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, " + f"scale={self.scale}, zero_point={self.zero_point}" + ) + + def _save_to_state_dict(self, destination, prefix, 
keep_vars): + # We cannot currently register scalar values as buffers, so need to manually + # specify serialization here. + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = self.scale + destination[prefix + "zero_point"] = self.zero_point + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + # Removing this function throws an error that the size of the loaded tensor does not match the original size + # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass. + local_state = ["scale", "zero_point"] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom handling to allow loading scale and zero_point + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. + if name == "scale": + self.scale.resize_(val.shape) + else: + assert name == "zero_point" + self.zero_point.resize_(val.shape) + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == "scale": + self.scale.copy_(val) + else: + assert name == "zero_point" + self.zero_point.copy_(val) + elif strict: + missing_keys.append(key) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class FixedQParamsFakeQuantize(FakeQuantize): + """Simulate quantize and dequantize in training time. + + Simulate quantize and dequantize with fixed quantization + parameters in training time. Only per tensor quantization + is supported. + """ + + # TODO: rename observer to observer_ctr + def __init__(self, observer): + super().__init__(observer=observer) + assert ( + type(self.activation_post_process) == FixedQParamsObserver + ), f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}" + self._observer_ctr = observer + self.scale = self.activation_post_process.scale + self.zero_point = self.activation_post_process.zero_point + assert _is_per_tensor(self.qscheme), ( + "Only per tensor quantization is supported" + + " FixedQParamsFakeQuantize module, got qscheme:" + + str(self.qscheme) + ) + + @torch.jit.export + def calculate_qparams(self): + return self.scale, self.zero_point + + @torch.jit.export + def extra_repr(self): + """Define a string representation of the object's attributes.""" + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"scale={self.scale}, zero_point={self.zero_point}, " + f"dtype={self.dtype}, quant_min={self.activation_post_process.quant_min}, " + f"quant_max={self.activation_post_process.quant_max}, qscheme={self.qscheme}" + ) + + +class FusedMovingAvgObsFakeQuantize(FakeQuantize): + r"""Define a fused module to observe the tensor. + + Fused module that is used to observe the input tensor (compute min/max), compute + scale/zero_point and fake_quantize the tensor. + This module uses calculation similar MovingAverageMinMaxObserver for the inputs, + to compute the min/max values in order to compute the scale/zero_point. + The qscheme input in the observer is used to differentiate between symmetric/affine + quantization scheme. 
+ + The output of this module is given by + x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale + + Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the + base class. + + """ + + def __init__( + self, + observer: Any = MovingAverageMinMaxObserver, + quant_min: int = 0, + quant_max: int = 255, + **observer_kwargs: Any, + ) -> None: + super().__init__(observer, quant_min, quant_max, **observer_kwargs) + assert isinstance( + self.activation_post_process, + (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver), + ), "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver" + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long)) + self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long)) + self.is_symmetric_quant = _is_symmetric_quant( + self.activation_post_process.qscheme + ) + + @torch.jit.export + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + return self.activation_post_process.calculate_qparams() + + @torch.jit.export + def extra_repr(self) -> str: + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}, " + f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, " + f"qscheme={self.qscheme}, reduce_range={self.activation_post_process.reduce_range}" + ) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + return torch.fused_moving_avg_obs_fake_quant( + X, + self.observer_enabled, + self.fake_quant_enabled, + self.activation_post_process.min_val, + self.activation_post_process.max_val, + self.scale, + self.zero_point, + self.activation_post_process.averaging_constant, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + self.ch_axis, + self.is_per_channel, + self.is_symmetric_quant, + ) + + +default_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, +) +""" +Default fake_quant for activations. +""" + +default_weight_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + reduce_range=False, +) +""" +Default fake_quant for weights. +Observer is memoryless since averaging_constant is 1. +""" + +default_dynamic_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + is_dynamic=True, + dtype=torch.quint8, + averaging_constant=1, +) +""" +Default dynamic fake_quant for activations. 
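+Since ``is_dynamic=True`` and ``averaging_constant=1``, this fake_quant acts as a
+placeholder for the dynamic quantization pattern (choose_qparams -> q -> dq) rather
+than the static (q -> dq) pattern.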
+""" + +default_fixed_qparams_range_neg1to1_fake_quant = FixedQParamsFakeQuantize.with_args( + observer=default_fixed_qparams_range_neg1to1_observer +) +default_fixed_qparams_range_0to1_fake_quant = FixedQParamsFakeQuantize.with_args( + observer=default_fixed_qparams_range_0to1_observer +) +# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases +default_symmetric_fixed_qparams_fake_quant = ( + default_fixed_qparams_range_neg1to1_fake_quant +) +default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant + +default_per_channel_weight_fake_quant = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + ch_axis=0, +) +""" +Default fake_quant for per-channel weights. +Observer is memoryless since averaging_constant is 1. +""" +default_embedding_fake_quant = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + qscheme=torch.per_channel_affine_float_qparams, + dtype=torch.quint8, + quant_min=0, + quant_max=255, + ch_axis=0, + averaging_constant=1, +) +""" +Default fake_quant for embeddings. +Observer is memoryless since averaging_constant is 1. +""" + +default_embedding_fake_quant_4bit = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0, + dtype=torch.quint4x2, + averaging_constant=1, +) + +default_histogram_fake_quant = FakeQuantize.with_args( + observer=HistogramObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, +) +""" +Fake_quant for activations using a histogram.. +""" + + +default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, +) + +""" +Fused version of `default_fake_quant`, with improved performance. +""" + + +default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, +) +""" +Fused version of `default_weight_fake_quant`, with improved performance. +""" + +default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, +) +""" +Fused version of `default_per_channel_weight_fake_quant`, with improved performance. +""" + +fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-127, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + eps=2**-12, +) +""" +Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +fused_per_channel_wt_fake_quant_range_neg_127_to_127 = ( + FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-127, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + eps=2**-12, + ) +) + +""" +Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128. 
+""" + + +def _is_fake_quant_script_module(mod): + """Return true if given mod is an instance of FakeQuantize script module.""" + if isinstance(mod, torch.jit.RecursiveScriptModule): + # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize' + suffix = mod._c.qualified_name.split(".", 1)[1] + name = re.sub(r"\.___torch_mangle_\d+", "", suffix) + return ( + name == "torch.ao.quantization.fake_quantize.FakeQuantize" + or name + == "torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize" + ) + return False + + +def disable_fake_quant(mod): + """Disable fake quantization for the module. + + Disable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.disable_fake_quant) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.disable_fake_quant() + + +def enable_fake_quant(mod): + """Enable fake quantization for the module. + + Enable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.enable_fake_quant) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.enable_fake_quant() + + +def disable_observer(mod): + """Disable observation for this module. + + Disable observation for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.disable_observer) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.disable_observer() + + +def enable_observer(mod): + """Enable observation for this module. + + Enable observation for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.enable_observer) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.enable_observer() diff --git a/torchao/quantization/pt2e_flow/observer.py b/torchao/quantization/pt2e_flow/observer.py new file mode 100644 index 0000000000..b3b16c14c5 --- /dev/null +++ b/torchao/quantization/pt2e_flow/observer.py @@ -0,0 +1,2050 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# temporarily skip RUF for this file for now, we can re-enable +# after move the affine quantization related things to torchao +# noqa: RUF +""" +This module implements observers which are used to collect statistics about +the values observed during calibration (PTQ) or training (QAT). 
+""" + +import re +import warnings +from abc import ABCMeta, abstractmethod +from collections import OrderedDict +from functools import partial +from typing import Any, Optional + +import torch +import torch.nn as nn + +import torchao +from torchao.quantization.pt2e_flow.utils import ( + calculate_qmin_qmax, + check_min_max_valid, + is_per_channel, + is_per_tensor, + validate_qmin_qmax, +) + +__all__ = [ + "default_affine_fixed_qparams_observer", + "default_debug_observer", + "default_dynamic_quant_observer", + "default_fixed_qparams_range_0to1_observer", + "default_fixed_qparams_range_neg1to1_observer", + "default_float_qparams_observer", + "default_float_qparams_observer_4bit", + "default_histogram_observer", + "default_observer", + "default_per_channel_weight_observer", + "default_placeholder_observer", + "default_reuse_input_observer", + "default_symmetric_fixed_qparams_observer", + "default_weight_observer", + "get_observer_state_dict", + "load_observer_state_dict", + "per_channel_weight_observer_range_neg_127_to_127", + "weight_observer_range_neg_127_to_127", + "FixedQParamsObserver", + "HistogramObserver", + "MinMaxObserver", + "MovingAverageMinMaxObserver", + "MovingAveragePerChannelMinMaxObserver", + "NoopObserver", + "ObserverBase", + "PerChannelMinMaxObserver", + "PlaceholderObserver", + "RecordingObserver", + "ReuseInputObserver", + "UniformQuantizationObserverBase", + "AffineQuantizedObserverBase", + "Granularity", + "MappingType", + "PerAxis", + "PerBlock", + "PerGroup", + "PerRow", + "PerTensor", + "PerToken", + "TorchAODType", + "ZeroPointDomain", + "get_block_size", +] + + +class _PartialWrapper: + def __init__(self, p): + self.p = p + self.callable_args = {} + + def __call__(self, *args, **keywords): + # call each arg in callable_args and add them partial, then run with keywords + # skip if arg_name in keywords so its possible to overwrite + for arg_name in self.callable_args: + if arg_name not in keywords: + keywords = {**keywords, arg_name: self.callable_args[arg_name]()} + return self.p(*args, **keywords) + + def __repr__(self): + return self.p.__repr__() + self.callable_args.__repr__() + + def with_args(self, **kwargs): + return _with_args(self, **kwargs) + + def with_callable_args(self, **kwargs): + result = _PartialWrapper(p=self.p) + result.callable_args = {**self.callable_args, **kwargs} + return result + + +def _with_args(cls_or_self, **kwargs): + r"""Wrapper that allows creation of class factories. + + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances. Can be used in conjunction with + _callable_args + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) + >>> foo_instance1 = foo_builder() + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1) == id(foo_instance2) + False + """ + r = _PartialWrapper(partial(cls_or_self, **kwargs)) + return r + + +def _with_callable_args(cls_or_self, **kwargs): + r"""Wrapper that allows creation of class factories args that need to be + called at construction time. + + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances and those arguments should only + be calculated at construction time. 
Can be used in conjunction with _with_args + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_callable_args = classmethod(_with_callable_args) + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan") + >>> foo_instance1 = foo_builder() + >>> # wait 50 + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time) + False + """ + r = _PartialWrapper(partial(cls_or_self)) + return r.with_callable_args(**kwargs) + + +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + + +class ObserverBase(ABC, nn.Module): + r"""Base observer Module. + Any observer implementation should derive from this class. + + Concrete observers should follow the same API. In forward, they will update + the statistics of the observed Tensor. And they should provide a + `calculate_qparams` function that computes the quantization parameters given + the collected statistics. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + is_dynamic: indicator for whether the observer is a placeholder for dynamic quantization + or static quantization + """ + + def __init__(self, dtype, is_dynamic: bool = False): + super().__init__() + self.dtype = dtype + self.is_dynamic = is_dynamic + + @abstractmethod + def forward(self, x): + pass + + @abstractmethod + def calculate_qparams(self, **kwargs): + pass + + with_args = classmethod(_with_args) + with_callable_args = classmethod(_with_callable_args) + + +class UniformQuantizationObserverBase(ObserverBase): + r"""Common base for all observers using uniform quantization to calculate + scale and zero_point. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used. + reduce_range: Reduces the range of the quantized data type by 1 bit. + This is sometimes required to avoid instruction overflow. + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + .. warning:: + + :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. + or `torch.int8` or `torch.uint8` + + .. warning:: + + :attr:`qscheme` can only take one of the following options: + + - ``torch.per_tensor_affine`` + - ``torch.per_tensor_symmetric`` + - ``torch.per_channel_affine`` + - ``torch.per_channel_symmetric`` + """ + + # Note: the version is shared by all observer types + # + # Version 1/None + # self + # + # Version 2 (base class only, does not include child class buffers) + # self + # |--- eps : Tensor + # + # Version 3 + # for HistogramObserver only, changed the shape of uninitialized + # min_val and max_val buffers from torch.Size([0]) to torch.Size([]) + # for PerChannelObservers, changed the name of the buffers from min_vals + # to min_val and from max_vals to max_val. 
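+ #
+ # Illustrative example (not part of the original comments) of the affine qparams
+ # math implemented in _calculate_qparams below, assuming torch.quint8 with
+ # quant_min=0, quant_max=255 and an observed min/max of -1.0 / 3.0:
+ # scale = (3.0 - (-1.0)) / (255 - 0) ~= 0.0157
+ # zero_point = 0 - round(-1.0 / 0.0157) = 64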
+ _version = 3 + + eps: torch.Tensor + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) + self.qscheme = qscheme + if reduce_range: + warnings.warn( + "Please use quant_min and quant_max to specify the range for observers. \ + reduce_range will be deprecated in a future release of PyTorch." + ) + self.reduce_range = reduce_range + self.register_buffer("eps", torch.tensor([eps], **factory_kwargs)) + assert self.qscheme in ( + torch.per_tensor_affine, + torch.per_tensor_symmetric, + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, + ), "Default Observer only works for per_tensor_affine, \ + per_tensor_symmetric, per_channel_affine, \ + per_channel_symmetric and per_channel_float_qparams quantization scheme" + + _ALLOWED_DTYPES = ( + torch.qint8, + torch.quint8, + torch.quint4x2, + torch.qint32, + torch.int8, + torch.uint8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.uint16, + ) + + assert ( + self.dtype in _ALLOWED_DTYPES + ), f"Default Observer only works for {_ALLOWED_DTYPES} data type" + self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) + if self.has_customized_qrange: + validate_qmin_qmax(quant_min, quant_max) + self.quant_min, self.quant_max = calculate_qmin_qmax( + quant_min, + quant_max, + self.has_customized_qrange, + self.dtype, + self.reduce_range, + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version == 1: + # eps was moved to a buffer in version 2 + eps = torch.tensor([torch.finfo(torch.float32).eps]) + state_dict[prefix + "eps"] = eps + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None: + r"""Validates that the user-specified quantization range is properly initialized + and within the given bound supported by the observer dtype. + + To accommodate lower-bit quantization with respect to the existing torch.qint8 and + torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing + in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax + values are used to calculate static estimates of the scale and zero point for aggressive lower-bit + fake quantization. These estimates are compared against parameters learned through backpropagation. + The related literatures for scale and zero point via backpropagation are as follows: + + Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS + Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf + """ + # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted + # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. + assert ( + quant_min <= 0 <= quant_max + ), "Used-specified quantization range must include 0." 
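+ # For example (illustrative values, not from the original code), a 4-bit signed
+ # range of quant_min=-8, quant_max=7 satisfies both assertions here, while
+ # quant_min=1, quant_max=15 would fail the "must include 0" check above.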
+ assert ( + quant_min < quant_max + ), "qmin must be strictly less than qmax for user-specified quantization range." + + @torch.jit.export + def _calculate_qparams( + self, min_val: torch.Tensor, max_val: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters, given min and max + value tensors. Works for both per tensor and per channel cases + + Args: + min_val: Minimum values per channel + max_val: Maximum values per channel + + Returns: + scales: Scales tensor of shape (#channels,) + zero_points: Zero points tensor of shape (#channels,) + """ + # Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable however and qscheme + # as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer + # to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code + # seems unlikey to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. + # TODO(jakeszwe, jerryzh168) + if not check_min_max_valid(min_val, max_val): + return torch.tensor([1.0], device=min_val.device.type), torch.tensor( + [0], device=min_val.device.type + ) + + quant_min, quant_max = self.quant_min, self.quant_max + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + + device = min_val_neg.device + scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + if ( + self.qscheme == torch.per_tensor_symmetric + or self.qscheme == torch.per_channel_symmetric + ): + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = torch.max(scale, self.eps) + if self.dtype in [torch.quint8, torch.uint8]: + if self.has_customized_qrange: + # When customized quantization range is used, down-rounded midpoint of the range is chosen. + zero_point = zero_point.new_full( + zero_point.size(), (quant_min + quant_max) // 2 + ) + else: + zero_point = zero_point.new_full(zero_point.size(), 128) + elif self.dtype in [torch.uint16]: + zero_point = zero_point.new_full(zero_point.size(), 2**15) + elif self.qscheme == torch.per_channel_affine_float_qparams: + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.where(scale > self.eps, scale, torch.ones_like(scale)) + # We use the quantize function + # xq = Round(Xf * inv_scale + zero_point), + # setting zero_point to (-1 * min *inv_scale) we get + # Xq = Round((Xf - min) * inv_scale) + zero_point = -1 * min_val / scale + else: + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.max(scale, self.eps) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + + # For scalar values, cast them to Tensors of size 1 to keep the shape + # consistent with default values in FakeQuantize. 
+ if len(scale.shape) == 0: + # TODO: switch to scale.item() after adding JIT support + scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) + if len(zero_point.shape) == 0: + # TODO: switch to zero_point.item() after adding JIT support + zero_point = torch.tensor( + [int(zero_point)], dtype=zero_point.dtype, device=device + ) + if self.qscheme == torch.per_channel_affine_float_qparams: + zero_point = torch.tensor( + [float(zero_point)], dtype=zero_point.dtype, device=device + ) + + return scale, zero_point + + @torch.jit.export + def reset_min_max_vals(self): + raise NotImplementedError("Cannot reset min/max values in the given observer.") + + +# Originally, this class was called `_ObserverBase`. Keeping the old name around +# for backwards compatibility. +# TODO(after v1.13): delete this +_ObserverBase = UniformQuantizationObserverBase + + +class MinMaxObserver(UniformQuantizationObserverBase): + r"""Observer module for computing the quantization parameters based on the + running min and max values. + + This observer uses the tensor min/max statistics to compute the quantization + parameters. The module records the running minimum and maximum of incoming + tensors, and uses this statistic to compute the quantization parameters. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`, + scale :math:`s` and zero point :math:`z` are computed as: + + The running minimum/maximum :math:`x_\text{min/max}` is computed as: + + .. math:: + + \begin{array}{ll} + x_\text{min} &= \begin{cases} + \min(X) & \text{if~}x_\text{min} = \text{None} \\ + \min\left(x_\text{min}, \min(X)\right) & \text{otherwise} + \end{cases}\\ + x_\text{max} &= \begin{cases} + \max(X) & \text{if~}x_\text{max} = \text{None} \\ + \max\left(x_\text{max}, \max(X)\right) & \text{otherwise} + \end{cases}\\ + \end{array} + + where :math:`X` is the observed tensor. + + The scale :math:`s` and zero point :math:`z` are then computed as: + + .. math:: + + \begin{aligned} + \text{if Symmetric:}&\\ + &s = 2 \max(|x_\text{min}|, x_\text{max}) / + \left( Q_\text{max} - Q_\text{min} \right) \\ + &z = \begin{cases} + 0 & \text{if dtype is qint8} \\ + 128 & \text{otherwise} + \end{cases}\\ + \text{Otherwise:}&\\ + &s = \left( x_\text{max} - x_\text{min} \right ) / + \left( Q_\text{max} - Q_\text{min} \right ) \\ + &z = Q_\text{min} - \text{round}(x_\text{min} / s) + \end{aligned} + + where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and + maximum of the quantized data type. + + .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. + + .. note:: If the running minimum equals to the running maximum, the scale + and zero_point are set to 1.0 and 0. 
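+
+ Example (illustrative sketch, not part of the original docs; uses the default
+ ``torch.quint8`` affine settings)::
+
+ >>> # xdoctest: +SKIP
+ >>> import torch
+ >>> from torchao.quantization.pt2e_flow import MinMaxObserver
+ >>> obs = MinMaxObserver()
+ >>> obs(torch.randn(4, 8)) # record the running min/max during calibration
+ >>> scale, zero_point = obs.calculate_qparams()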
+ """ + + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + "MinMaxObserver's qscheme only support torch.per_tensor_symmetric \ + and torch.per_tensor_affine." + ) + # TODO: MinMaxObserver by itself doesn't support dynamic quantization, but + # if it's inherited by MovingAverageObserver, and averaging_constant is 1, it + # supports dynamic quantization, we may need to better error checking here + + # For x86 quantized kernels, we need to ensure that the vpmaddubsw + # instruction does not overflow. We allow for a reduce_range argument to + # observers that reduces the quantized range to (0,127) or (-64, 63). + # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp + # This is not an optimal choice for non x86 backends as it loses a bit + # of precision for activations. + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + if ( + self.qscheme == torch.per_tensor_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric \ + quantization for quint8" + ) + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = torch.min(min_val_cur, self.min_val) + max_val = torch.max(max_val_cur, self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + r"""Calculates the quantization parameters.""" + return self._calculate_qparams(self.min_val, self.max_val) + + @torch.jit.export + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + self.min_val.copy_(torch.tensor(float("inf"))) + self.max_val.copy_(torch.tensor(float("-inf"))) + + +class MovingAverageMinMaxObserver(MinMaxObserver): + r"""Observer module for computing the quantization parameters based on the + moving average of the min and max values. + + This observer computes the quantization parameters based on the moving + averages of minimums and maximums of the incoming tensors. The module + records the average minimum and maximum of incoming tensors, and uses this + statistic to compute the quantization parameters. + + Args: + averaging_constant: Averaging constant for min/max. + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. 
+ eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The moving average min/max is computed as follows + + .. math:: + + \begin{array}{ll} + x_\text{min} = \begin{cases} + \min(X) & \text{if~}x_\text{min} = \text{None} \\ + (1 - c) x_\text{min} + c \min(X) & \text{otherwise} + \end{cases}\\ + x_\text{max} = \begin{cases} + \max(X) & \text{if~}x_\text{max} = \text{None} \\ + (1 - c) x_\text{max} + c \max(X) & \text{otherwise} + \end{cases}\\ + \end{array} + + where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is + is the incoming tensor, and :math:`c` is the ``averaging_constant``. + + The scale and zero point are then computed as in + :class:`~torchao.quantization.pt2e_flow.observer.MinMaxObserver`. + + .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme. + + .. note:: If the running minimum equals to the running maximum, the scale + and zero_point are set to 1.0 and 0. + """ + + def __init__( + self, + averaging_constant=0.01, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + f"MovingAverageMinMaxObserver's qscheme only support \ + torch.per_tensor_symmetric and torch.per_tensor_affine. \ + but got: {qscheme}" + ) + self.averaging_constant = averaging_constant + if is_dynamic and self.averaging_constant != 1: + raise NotImplementedError( + "MovingAverageMinMaxObserver doesn't support dynamic quantization for " + f"averaging constant of {self.averaging_constant}" + ) + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val = self.min_val + max_val = self.max_val + if min_val == float("inf") and max_val == float("-inf"): + min_val, max_val = torch.aminmax(x) + else: + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = min_val + self.averaging_constant * (min_val_cur - min_val) + max_val = max_val + self.averaging_constant * (max_val_cur - max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + +class PerChannelMinMaxObserver(UniformQuantizationObserverBase): + r"""Observer module for computing the quantization parameters based on the + running per channel min and max values. + + This observer uses the tensor min/max statistics to compute the per channel + quantization parameters. The module records the running minimum and maximum + of incoming tensors, and uses this statistic to compute the quantization + parameters. + + Args: + ch_axis: Channel axis + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. 
+ + The quantization parameters are computed the same way as in + :class:`~torchao.quantization.pt2e_flow.observer.MinMaxObserver`, with the difference + that the running min/max values are stored per channel. + Scales and zero points are thus computed per channel as well. + + .. note:: If the running minimum equals to the running maximum, the scales + and zero_points are set to 1.0 and 0. + """ + + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_channel(qscheme): + raise NotImplementedError( + "PerChannelMinMaxObserver's qscheme only support \ + torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." + ) + if is_dynamic: + raise NotImplementedError( + "PerChannelMinMaxObserver doesn't support dynamic quantization" + ) + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.ch_axis = ch_axis + self.register_buffer("min_val", torch.tensor([], **factory_kwargs)) + self.register_buffer("max_val", torch.tensor([], **factory_kwargs)) + if ( + self.qscheme == torch.per_channel_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric quantization for quint8" + ) + + def forward(self, x_orig): + return self._forward(x_orig) + + def _forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + return self._calculate_qparams(self.min_val, self.max_val) + + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + def _load_from_state_dict( + self, + state_dict: dict[str, Any], + prefix: str, + local_metadata: dict[str, torch.Tensor], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ): + version = local_metadata.get("version", None) + if version is not None and version < 3: + local_state = ["min_vals", "max_vals"] + expected_min_name = "min_vals" + expected_max_name = "max_vals" + else: + local_state = ["min_val", "max_val"] + expected_min_name = "min_val" + expected_max_name = "max_val" + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom 
handling to allow loading min_val or max_val + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. + if name == expected_min_name: + self.min_val.resize_(val.shape) + elif name == expected_max_name: + self.max_val.resize_(val.shape) + else: + warnings.warn( + f"Observer load_from_state_dict got unexpected name {name}" + ) + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == expected_min_name: + self.min_val.copy_(val) + elif name == expected_max_name: + self.max_val.copy_(val) + else: + warnings.warn( + f"Observer load_from_state_dict got unexpected name {name}" + ) + elif strict: + missing_keys.append(key) + + if not torch.jit.is_scripting(): + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _load_from_state_dict_script( + self, + state_dict: dict[str, Any], + prefix: str, + local_metadata: dict[str, torch.Tensor], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ): + self._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + # This used to be torch.ones but that does not work because + # JIT compiler can optimize it via common subexpression elimination + # in which case both min_val and max_val point to the same tensor. + self.min_val = torch.rand( + 0, + ) + self.max_val = torch.rand( + 0, + ) + + +class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver): + r"""Observer module for computing the quantization parameters based on the + running per channel min and max values. + + This observer uses the tensor min/max statistics to compute the per channel + quantization parameters. The module records the running minimum and maximum + of incoming tensors, and uses this statistic to compute the quantization + parameters. + + Args: + averaging_constant: Averaging constant for min/max. + ch_axis: Channel axis + dtype: Quantized data type + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The quantization parameters are computed the same way as in + :class:`~torchao.quantization.pt2e_flow.observer.MovingAverageMinMaxObserver`, with the + difference that the running min/max values are stored per channel. + Scales and zero points are thus computed per channel as well. + + .. note:: If the running minimum equals to the running maximum, the scales + and zero_points are set to 1.0 and 0. 
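+
+ Example (illustrative sketch, not part of the original docs; the weight shape and
+ ``averaging_constant`` are arbitrary choices)::
+
+ >>> # xdoctest: +SKIP
+ >>> import torch
+ >>> from torchao.quantization.pt2e_flow import MovingAveragePerChannelMinMaxObserver
+ >>> obs = MovingAveragePerChannelMinMaxObserver(averaging_constant=0.01, ch_axis=0)
+ >>> obs(torch.randn(16, 3, 3, 3)) # one running min/max pair per output channel
+ >>> scales, zero_points = obs.calculate_qparams() # tensors of shape (16,)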
+ """ + + def __init__( + self, + averaging_constant=0.01, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_channel(qscheme): + raise NotImplementedError( + "MovingAveragePerChannelMinMaxObserver's qscheme only support \ + torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." + ) + if is_dynamic: + raise NotImplementedError( + "MovingAveragePerChannelMinMaxObserver doesn't support dynamic quantization" + ) + super().__init__( + ch_axis=ch_axis, + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + self.averaging_constant = averaging_constant + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = min_val + self.averaging_constant * (min_val_cur - min_val) + max_val = max_val + self.averaging_constant * (max_val_cur - max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + +class HistogramObserver(UniformQuantizationObserverBase): + r""" + The module records the running histogram of tensor values along with + min/max values. ``calculate_qparams`` will calculate scale and zero_point. + + Args: + bins: Number of bins to use for the histogram + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The scale and zero point are computed as follows: + + 1. Create the histogram of the incoming inputs. + The histogram is computed continuously, and the ranges per bin change + with every new tensor observed. + 2. Search the distribution in the histogram for optimal min/max values. + The search for the min/max values ensures the minimization of the + quantization error with respect to the floating point model. + 3. Compute the scale and zero point the same way as in the + :class:`~torchao.quantization.pt2e_flow.MinMaxObserver` + """ + + histogram: torch.Tensor + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + bins: int = 2048, + dtype: torch.dtype = torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + "HistogramObserver's qscheme only support torch.per_tensor_symmetric \ + and torch.per_tensor_affine." 
+ ) + if is_dynamic: + raise NotImplementedError( + "HistogramObserver doesn't support dynamic quantization" + ) + # bins: The number of bins used for histogram calculation. + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.bins = bins + self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs)) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits + self.upsample_rate = ( + 16 # used to reduce quantization errors when upscaling histogram + ) + + def _get_norm( + self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor + ) -> torch.Tensor: + r""" + Compute the norm of the values uniformaly distributed between + delta_begin and delta_end. + Currently only L2 norm is supported. + + norm = density * (integral_{begin, end} x^2) + = density * (end^3 - begin^3) / 3 + """ + norm = ( + delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin + ) / 3 + return density * norm + + def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int): + r""" + Compute the quantization error if we use start_bin to end_bin as the + min and max to do the quantization. + """ + bin_width = (self.max_val.item() - self.min_val.item()) / self.bins + + dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins + if dst_bin_width == 0.0: + return 0.0 + + src_bin = torch.arange(self.bins, device=self.histogram.device) + # distances from the beginning of first dst_bin to the beginning and + # end of src_bin + src_bin_begin = (src_bin - next_start_bin) * bin_width + src_bin_end = src_bin_begin + bin_width + + # which dst_bins the beginning and end of src_bin belong to? + dst_bin_of_begin = torch.clamp( + torch.div(src_bin_begin, dst_bin_width, rounding_mode="floor"), + 0, + self.dst_nbins - 1, + ) + dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width + + dst_bin_of_end = torch.clamp( + torch.div(src_bin_end, dst_bin_width, rounding_mode="floor"), + 0, + self.dst_nbins - 1, + ) + density = self.histogram / bin_width + + norm = torch.zeros(self.bins, device=self.histogram.device) + + delta_begin = src_bin_begin - dst_bin_of_begin_center + delta_end = dst_bin_width / 2 + norm += self._get_norm( + delta_begin, + torch.ones(self.bins, device=self.histogram.device) * delta_end, + density, + ) + + norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm( + torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density + ) + + dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2 + + delta_begin = -dst_bin_width / 2 + delta_end = src_bin_end - dst_bin_of_end_center + norm += self._get_norm(torch.tensor(delta_begin), delta_end, density) + + return norm.sum().item() + + def _non_linear_param_search(self) -> tuple[torch.Tensor, torch.Tensor]: + r"""Non-linear parameter search. + + An approximation for L2 error minimization for selecting min/max. + By selecting new min/max, we filter out outliers in input distribution. 
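The closed form used by `_get_norm` above is the integral of x^2 weighted by a constant density, i.e. density * (end^3 - begin^3) / 3. A small self-contained check against numerical integration (not part of the patch; the sample values are arbitrary)::

    import torch

    def norm_closed_form(begin, end, density):
        # density * integral_{begin}^{end} x^2 dx
        return density * (end**3 - begin**3) / 3

    begin, end, density = -0.3, 0.7, 2.5
    xs = torch.linspace(begin, end, 100_000)
    numeric = density * torch.trapezoid(xs**2, xs)

    assert torch.isclose(
        torch.tensor(norm_closed_form(begin, end, density)), numeric, rtol=1e-4
    )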
+ This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in + caffe2/quantization/server/norm_minimization.cc + """ + assert self.histogram.size()[0] == self.bins, "bins mismatch" + bin_width = (self.max_val - self.min_val) / self.bins + + # cumulative sum + total = torch.sum(self.histogram).item() + cSum = torch.cumsum(self.histogram, dim=0) + + stepsize = 1e-5 # granularity + alpha = 0.0 # lower bound + beta = 1.0 # upper bound + start_bin = 0 + end_bin = self.bins - 1 + norm_min = float("inf") + + while alpha < beta: + # Find the next step + next_alpha = alpha + stepsize + next_beta = beta - stepsize + + # find the left and right bins between the quantile bounds + l = start_bin + r = end_bin + while l < end_bin and cSum[l] < next_alpha * total: + l = l + 1 + while r > start_bin and cSum[r] > next_beta * total: + r = r - 1 + + # decide the next move + next_start_bin = start_bin + next_end_bin = end_bin + if (l - start_bin) > (end_bin - r): + # move the start bin + next_start_bin = l + alpha = next_alpha + else: + # move the end bin + next_end_bin = r + beta = next_beta + + if next_start_bin == start_bin and next_end_bin == end_bin: + continue + + # calculate the quantization error using next_start_bin and next_end_bin + norm = self._compute_quantization_error(next_start_bin, next_end_bin) + + if norm > norm_min: + break + norm_min = norm + start_bin = next_start_bin + end_bin = next_end_bin + + new_min = self.min_val + bin_width * start_bin + new_max = self.min_val + bin_width * (end_bin + 1) + return new_min, new_max + + def _upscale_histogram( + self, + histogram: torch.Tensor, + orig_min: torch.Tensor, + orig_max: torch.Tensor, + update_min: torch.Tensor, + update_max: torch.Tensor, + ): + # this turns the histogram into a more fine-coarsed histogram to reduce + # bin quantization errors + histogram = histogram.repeat_interleave(self.upsample_rate) / self.upsample_rate + bin_size = (orig_max - orig_min) / (self.bins * self.upsample_rate) + mid_points_histogram = ( + torch.linspace( + orig_min, + orig_max, + self.bins * self.upsample_rate + 1, + device=orig_min.device, + )[:-1].to(histogram.device) + + 0.5 * bin_size + ) + boundaries_new_histogram = torch.linspace( + update_min, update_max, self.bins + 1, device=update_min.device + ).to(histogram.device) + # this maps the mid-poits of the histogram to the new histogram's space + bucket_assignments = ( + torch.bucketize(mid_points_histogram, boundaries_new_histogram, right=True) + - 1 + ) + # this then maps the histogram mid-points in the new space, weighted by the original histogram's values + # this is just the old histogram in the new histogram's space + + # In case due to numerical issues the values land higher/lower than the maximum/minimum + bucket_assignments[bucket_assignments >= self.bins] = self.bins - 1 + bucket_assignments[bucket_assignments < 0] = 0 + + update_histogram = torch.bincount( + bucket_assignments, weights=histogram, minlength=self.bins + ) + return update_histogram + + def _combine_histograms( + self, + orig_hist: torch.Tensor, + orig_min: torch.Tensor, + orig_max: torch.Tensor, + update_hist: torch.Tensor, + update_min: torch.Tensor, + update_max: torch.Tensor, + ) -> torch.Tensor: + # If the new min and max are the same as the current min and max, + # we can just add the new histogram to the original histogram + if update_min == orig_min and update_max == orig_max: + return orig_hist + update_hist + + # If the orig hist only has one value (i.e., the min and max are the same) + # 
we can just add it into new histogram + if orig_min == orig_max: + bin_value = torch.sum(update_hist) + transformed_orig_hist = ( + torch.histc(orig_min, bins=self.bins, min=update_min, max=update_max) # type: ignore[arg-type] + * bin_value + ) + return transformed_orig_hist + update_hist + + # We assume the update_hist is already in the target range, we will map the orig_max to it + assert update_min <= orig_min + assert update_max >= orig_max + + # Now we need to turn the old_histogram, into the range of the new histogram + transformed_orig_hist = self._upscale_histogram( + orig_hist, + orig_min, + orig_max, + update_min, + update_max, + ) + + return update_hist + transformed_orig_hist + + def reset_histogram( + self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor + ) -> None: + self.min_val.resize_(min_val.shape) + self.min_val.copy_(min_val) + self.max_val.resize_(max_val.shape) + self.max_val.copy_(max_val) + assert ( + min_val.numel() == 1 and max_val.numel() == 1 + ), "histogram min/max values must be scalar." + new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val) # type: ignore[arg-type] + self.histogram.detach_().resize_(new_histogram.shape) + self.histogram.copy_(new_histogram) + + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: # pyre-ignore[14] + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() + x_min, x_max = torch.aminmax(x) + # want to ignore torch.inf since we don't actually + # want to make our quantization range infinite + # and in practice those values will be clamped + if x_min == -torch.inf or x_max == torch.inf: + warnings.warn("torch.inf detected in input tensor, ignoring input") + x = x[x.abs() != torch.inf] + if x.numel() == 0: + return x_orig + x_min, x_max = torch.aminmax(x) + + current_min = self.min_val + current_max = self.max_val + + is_uninitialized = self.min_val == float("inf") or self.max_val == float("-inf") + if is_uninitialized: + self.reset_histogram(x, x_min, x_max) + else: + update_min, update_max = x_min, x_max + new_min = torch.min(current_min, update_min) + new_max = torch.max(current_max, update_max) + + # TODO: For some reason, this is required for it to pass torchscript test + # new_min and new_max should already have requires_grad set to False + new_min, new_max = new_min.detach(), new_max.detach() + update_histogram = torch.histc( + x, + self.bins, + min=new_min, + max=new_max, # type: ignore[arg-type] + ).to(self.histogram.device) + if new_min == current_min and new_max == current_max: + combined_histogram = self.histogram + update_histogram + self.histogram.detach_().resize_(combined_histogram.shape) + self.histogram.copy_(combined_histogram) + else: + combined_histogram = self._combine_histograms( + self.histogram, + current_min, + current_max, + update_histogram, + new_min, + new_max, + ) + self.histogram.detach_().resize_(combined_histogram.shape) + self.histogram.copy_(combined_histogram) + self.min_val.detach_().resize_(new_min.shape) + self.min_val.copy_(new_min) + self.max_val.detach_().resize_(new_max.shape) + self.max_val.copy_(new_max) + + return x_orig + + @torch.jit.export + def calculate_qparams(self): + is_uninitialized = self.min_val == float("inf") and self.max_val == float( + "-inf" + ) + if is_uninitialized: + warnings.warn( + "must run observer before calling calculate_qparams.\ + Returning default scale and zero point " + ) + return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor( + [0], device=self.min_val.device.type + ) + assert self.bins == 
len(self.histogram), ( + "The number of bins in histogram should be equal to the number of bins " + "supplied while making this observer" + ) + + new_min, new_max = self._non_linear_param_search() + + return self._calculate_qparams(new_min, new_max) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "min_val"] = self.min_val + destination[prefix + "max_val"] = self.max_val + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 3: + # if min_val and max_val are not initialized, update their shape + # to account for the differences between v2 and v3 + min_val_name, max_val_name = prefix + "min_val", prefix + "max_val" + if min_val_name in state_dict: + if state_dict[min_val_name].shape == torch.Size([0]): + state_dict[min_val_name] = torch.tensor(float("inf")) + if max_val_name in state_dict: + if state_dict[max_val_name].shape == torch.Size([0]): + state_dict[max_val_name] = torch.tensor(float("-inf")) + + local_state = ["min_val", "max_val"] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + setattr(self, name, val) + elif strict: + missing_keys.append(key) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + +class FixedQParamsObserver(ObserverBase): + r""" + Observer that simulates quantize and dequantize with fixed + quantization parameters in training time. Only per tensor + quantization is supported. + + Args: + `scale` (float): fixed scale for the observer + `zero_point` (int): fixed zero point for the observer + `dtype`, `qscheme`, `quant_min`, `quant_max` + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__( + self, + scale, + zero_point, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + quant_min=0, + quant_max=255, + is_dynamic=False, + **kwargs, + ): + if is_dynamic: + raise NotImplementedError( + "FixedQParamsObserver doesn't support dynamic quantization" + ) + super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) + self.quant_min = quant_min + self.quant_max = quant_max + self.register_buffer("scale", torch.tensor([scale], dtype=torch.float)) + self.register_buffer("zero_point", torch.tensor([zero_point], dtype=torch.int)) + self.dtype = dtype + self.qscheme = qscheme + + def forward(self, X): + return X + + @torch.jit.export + def calculate_qparams(self): + return self.scale, self.zero_point + + +class PlaceholderObserver(ObserverBase): + r""" + Observer that doesn't do anything and just passes its configuration to the + quantized module's ``.from_float()``. + + Can be used for quantization to float16 which doesn't require determining + ranges. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + quant_min: minimum value in quantized domain (TODO: align behavior with other observers) + quant_max: maximum value in quantized domain + custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation + (Can be used in Graph Mode Passes for special case ops). 
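Calibration with the histogram observer is just repeated forward calls followed by `calculate_qparams`, which runs the non-linear min/max search described above. A minimal sketch, assuming this patch is applied; the bin count and data are illustrative::

    import torch

    from torchao.quantization.pt2e_flow.observer import HistogramObserver

    obs = HistogramObserver(bins=256, dtype=torch.quint8, qscheme=torch.per_tensor_affine)

    for _ in range(20):
        obs(torch.randn(32, 128))  # folds each batch into the running histogram / min / max

    scale, zero_point = obs.calculate_qparams()  # triggers _non_linear_param_search
    print(scale.item(), zero_point.item())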
+ compute_dtype (deprecated): if set, marks the future quantize function to use + dynamic quantization instead of static quantization. + This field is deprecated, use `is_dynamic=True` instead. + is_dynamic: if True, the `quantize` function in the reference model + representation taking stats from this observer instance will + use dynamic quantization. + """ + + def __init__( + self, + dtype=torch.float32, + custom_op_name="", + compute_dtype=None, + quant_min=None, + quant_max=None, + qscheme=None, + eps=None, + is_dynamic=False, + ) -> None: + super().__init__(dtype=dtype, is_dynamic=is_dynamic) + if qscheme is None: + qscheme = torch.per_tensor_affine + if eps is None: + eps = torch.finfo(torch.float32).eps + + # dtype of input of the target operator, e.g. for dynamic quantization + # ops, the dtype will be float32 + self.dtype = dtype + self.qscheme = qscheme + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.custom_op = custom_op_name + # used for configuration of computation type for dynamic quantization + if compute_dtype: + is_dynamic = True + warnings.warn( + "Please use `is_dynamic` instead of `compute_dtype`. \ + `compute_dtype` will be deprecated in a future release \ + of PyTorch." + ) + + def forward(self, x): + return x + + @torch.jit.export + def extra_repr(self): + return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}" + + @torch.jit.export + def calculate_qparams(self): + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for PlaceholderObserver" + ) + + +class RecordingObserver(ObserverBase): + r""" + The module is mainly for debug and records the tensor values during runtime. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + """ + + __annotations__ = {"tensor_val": list[Optional[torch.Tensor]]} + + def __init__(self, dtype=torch.quint8): + super().__init__(dtype=dtype, is_dynamic=False) + self.tensor_val = [] + + def forward(self, x): + self.tensor_val.append(x.clone()) + return x + + @torch.jit.export + def calculate_qparams(self): + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for RecordingObserver" + ) + + @torch.jit.export + def get_tensor_value(self): + return self.tensor_val + + +class NoopObserver(ObserverBase): + r""" + Observer that doesn't do anything and just passes its configuration to the + quantized module's ``.from_float()``. + + Primarily used for quantization to float16 which doesn't require determining + ranges. + + Args: + dtype: Quantized data type + custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation + (Can be used in Graph Mode Passes for special case ops). + """ + + def __init__(self, dtype=torch.float16, custom_op_name="") -> None: + super().__init__(dtype=dtype, is_dynamic=False) + self.dtype = dtype + self.custom_op = custom_op_name + + def forward(self, x): + return x + + @torch.jit.export + def calculate_qparams(self): + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for NoopObserver" + ) + + +class ReuseInputObserver(ObserverBase): + r"""This observer is used when we want to reuse the observer from the operator + that produces the input Tensor, typically used for operators like reshape, e.g. + ``` + x0 = ... 
+ x1 = x0.reshape() + ``` + if we configure x0 to be observed by some observer, let's say MinMaxObserver, + and reshape is configured with ReuseInputObserver, we'll reuse the observer instance + for x0 for x1 (output of reshape). If x0 is not observed, we also won't observe x1. + + Note: this is only enabled in FX Graph Mode Quantization + """ + + def __init__(self) -> None: + super().__init__(torch.quint8, is_dynamic=False) + + def forward(self, x): + return x + + @torch.jit.export + def calculate_qparams(self): + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for ReuseInputObserver" + ) + + +""" +# Experimental Affine Quantization Feature START +We plan to merge the following with torchao repo after we move pt2e flow to torchao +copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py +""" +from dataclasses import dataclass +from enum import Enum, auto + + +class MappingType(Enum): + """How floating point number is mapped to integer number + + symmetric mapping means floating point range is symmetrically mapped to integer range + let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4) + we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7) + e.g. scale = (10.2 - (-10.2)) / (7 - (-8)) + + SYMMETRIC_NO_CLIPPING_ERR is a variant of symmetric mapping, where the scale is the max of smin + and smax, where smin = min_val_neg / quant_min, and smax = max_val_pos / quant_max. By calculating + smin and smax individually, there can be less round error on negative values, and no out-of-range + of all floating point values. + + asymmetric mapping means we just directly map the floating point range to integer range, + for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter + based on this mapping + e.g. scale = (10.2 - (-3.5)) / (7 - (-8)) + """ + + SYMMETRIC = auto() + SYMMETRIC_NO_CLIPPING_ERR = auto() + ASYMMETRIC = auto() + + +class ZeroPointDomain(Enum): + """Enum that indicate whether zero_point is in integer domain or floating point domain + + integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer) + float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + none domain: quantized_val = (float_val / scale) + """ + + INT = auto() + FLOAT = auto() + NONE = auto() + + +class TorchAODType(Enum): + """ + Placeholder for dtypes that do not exist in PyTorch core yet. + """ + + # torch.int1 to torch.int7 will be added to PyTorch 2.6 + # These will remain here for BC with older PyTorch versions + INT1 = auto() + INT2 = auto() + INT3 = auto() + INT4 = auto() + INT5 = auto() + INT6 = auto() + INT7 = auto() + + +@dataclass(frozen=True) +class Granularity: + """ + Base class for representing the granularity of quantization. + + This class serves as a parent for specific granularity types used in + quantization operations, such as per-tensor or per-axis quantization. + """ + + +@dataclass(frozen=True) +class PerBlock(Granularity): + """ + Represents per-block granularity in quantization. See + :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for + `block_size` + + Attributes: + block_size (Tuple[int, ...]): The size of each quantization group + """ + + block_size: tuple[int, ...] + + +@dataclass(frozen=True) +class PerTensor(Granularity): + """ + Represents per-tensor granularity in quantization. 
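The symmetric and asymmetric mappings described in the `MappingType` docstring above differ only in which floating-point range is divided by the integer range. A worked check of the docstring's own numbers (float range (-3.5, 10.2), int4 range (-8, 7)); the zero-point line follows the usual asymmetric convention and is illustrative only::

    min_val, max_val = -3.5, 10.2
    quant_min, quant_max = -8, 7  # int4

    # symmetric: widen the float range to (-max_abs, max_abs) first
    max_abs = max(abs(min_val), abs(max_val))
    scale_symmetric = (max_abs - (-max_abs)) / (quant_max - quant_min)  # 20.4 / 15 = 1.36

    # asymmetric: map the observed float range onto the integer range directly
    scale_asymmetric = (max_val - min_val) / (quant_max - quant_min)  # 13.7 / 15 ~= 0.913
    zero_point = round(quant_min - min_val / scale_asymmetric)  # -4 for these values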
+ + This granularity type calculates the quantization parameters + based off the entire tensor. + + """ + + +@dataclass(frozen=True) +class PerAxis(Granularity): + """ + Represents per-axis granularity in quantization. + + This granularity type calculates different quantization parameters + along a specified axis of the tensor. + + For example if the input tensor is shape [8, 16] and axis=0, then + the quantization parameters are calculated for each row of the tensor. + Giving a total of 8 quantization parameters. + + Attributes: + axis (int): The axis along which reduction is performed. + """ + + axis: int + + +@dataclass(frozen=True) +class PerGroup(Granularity): + """ + Represents per-channel group granularity in quantization. + + This granularity type calculates different quantization parameters + for each group of elements. + + For example if the input tensor is shape [8, 16], and the group size is 4, then + the input tensor is reshaped to [64, 4] + quantization parameters are calculated for each group of 4 elements, + giving a total of 64 quantization parameters. + + Attributes: + group_size (int): The size of each quantization group + + """ + + group_size: int + + +class PerRow(Granularity): + """ + Represents row-wise granularity in quantization. + + This is a special case of per-axis quantization and is unique to Float8 matmuls + where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight + is quantized with a block_size of (1, weight.shape[1]). + """ + + +class PerToken(Granularity): + """ + Represents per-token granularity in quantization. + + This granularity type calculates a different set of quantization parameters + for each token, which is represented as the last dimension of the tensor. + + For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens + with 4 elements each, and we will calculate 6 sets of quantization parameters, + one for each token. + + If the input tensor has only two dimensions, e.g. [8, 16], then this is + equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters. + """ + + +def get_block_size( + input_shape: tuple[int, ...], granularity: Granularity +) -> tuple[int, ...]: + """Get the block size based on the input shape and granularity type. 
+ + Args: + input_shape: The input tensor shape possibly more than 2 dimensions + granularity: The granularity type of the quantization + """ + assert isinstance( + granularity, Granularity + ), "Please provide an instance of Granularity, not subclass of it" + if isinstance(granularity, PerTensor): + return input_shape + elif isinstance(granularity, PerAxis): + block_size = list(input_shape) + block_size[granularity.axis] = 1 + return tuple(block_size) + elif isinstance(granularity, PerRow): + return (1,) * (len(input_shape) - 1) + (input_shape[-1],) + elif isinstance(granularity, PerGroup): + assert ( + len(input_shape) == 2 + ), f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" + return (1, granularity.group_size) + elif isinstance(granularity, PerToken): + block_size = list(input_shape) + block_size[-1] = input_shape[-1] + return tuple(block_size) + raise ValueError(f"Unsupported Granularity: {granularity}") + + +class AffineQuantizedObserverBase(ABC, torch.nn.Module): + """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) + + Args: + `granularity` and `block_size`: The granularity of the quantization, + must specify at least one, if both are specified `block_size` takes precedence + Current supported granularity type are `PerTensor` and `PerAxis` + other args: please see `:class:torchao.dtypes.AffineQuantizedTensor` + """ + + with_args = classmethod(_with_args) + + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + super().__init__() + assert granularity is not None, "granularity is None" + + self.mapping_type = mapping_type + self.target_dtype = target_dtype + self.granularity = granularity + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.scale_dtype = scale_dtype + self.zero_point_dtype = zero_point_dtype + self.preserve_zero = preserve_zero + self.zero_point_domain = zero_point_domain + # populatd during forward + self.block_size = None + self.original_dtype = None + + @abstractmethod + def forward(self, input: torch.Tensor) -> torch.Tensor: + """forward function should take the input tensor + and updates internal stats and return the original input Tensor + """ + + @abstractmethod + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + """Calculate quantization parameter based on the stats attached to the observer module + and returns a tuple of scale and zero_point Tensor + """ + + +def _is_observer_script_module(mod, obs_type_name): + """Returns true if given mod is an instance of Observer script module.""" + if isinstance(mod, torch.jit.RecursiveScriptModule): + # qualified name looks like '__torch__.torchao.quantization.pt2e_flow.observer.___torch_mangle_2.MinMaxObserver' + suffix = mod._c.qualified_name.split(".", 1)[1] + name = re.sub(r"\.___torch_mangle_\d+", "", suffix) + return obs_type_name in name + return False + + +# Experimental Affine Quantization Feature END + + +def _is_activation_post_process(module): + return isinstance( + module, + ( + 
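Each granularity above reduces to a `block_size` tuple through `get_block_size`; for the `[8, 16]` example shape used in the docstrings the mapping looks like this (a sketch, assuming this patch is applied)::

    from torchao.quantization.pt2e_flow.observer import (
        PerAxis,
        PerGroup,
        PerTensor,
        get_block_size,
    )

    shape = (8, 16)
    print(get_block_size(shape, PerTensor()))             # (8, 16): one qparam for the whole tensor
    print(get_block_size(shape, PerAxis(axis=0)))         # (1, 16): one qparam per row, 8 in total
    print(get_block_size(shape, PerGroup(group_size=4)))  # (1, 4): one qparam per group of 4 elements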
torchao.quantization.pt2e_flow.ObserverBase, + torchao.quantization.pt2e_flow.FakeQuantizeBase, + AffineQuantizedObserverBase, + ), + ) or _is_observer_script_module(module, "quantization.observer") + + +def _is_per_channel_script_obs_instance(module): + if isinstance(module, torch.jit.RecursiveScriptModule): + return _is_observer_script_module( + module, "quantization.observer.PerChannelMinMaxObserver" + ) or _is_observer_script_module( + module, "quantization.observer.MovingAveragePerChannelMinMaxObserver" + ) + return False + + +def get_observer_state_dict(mod): + r""" + Returns the state dict corresponding to the observer stats. + Traverse the model state_dict and extract out the stats. + """ + od = OrderedDict() + if isinstance(mod, torch.jit.RecursiveScriptModule): + for k, v in mod.state_dict().items(): + if "observer" in k: + od[k] = v + else: + # path for GraphModule and nn.Module (eager mode) + for k, v in mod.state_dict().items(): + if "activation_post_process" in k: + od[k] = v + od._metadata = mod.state_dict()._metadata # type: ignore[attr-defined] + return od + + +def load_observer_state_dict(mod, obs_dict): + r""" + Given input model and a state_dict containing model observer stats, + load the stats back into the model. The observer state_dict can be saved + using torchao.quantization.pt2e_flow.get_observer_state_dict + """ + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + for name, module in mod.named_modules(): + prefix = name + "." + if _is_activation_post_process(module): + if _is_per_channel_script_obs_instance(module): + # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor. + # However this is not called when the module is scripted and we end up calling the default one in module.py + module._load_from_state_dict_script( + obs_dict, prefix, {}, True, missing_keys, unexpected_keys, [] + ) + else: + module._load_from_state_dict( + obs_dict, prefix, {}, False, missing_keys, unexpected_keys, [] + ) + for k in missing_keys: + if "observer" in k or "activation_post_process" in k: + raise Exception( # noqa: TRY002 + f"Missing keys for observer {k} in state_dict" + ) + for k in unexpected_keys: + if "observer" in k or "activation_post_process" in k: + raise Exception( # noqa: TRY002 + f"Unexpected keys for observer {k} in state_dict" + ) + + +# Restrict activations to be in the range (0,127) +default_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127) +""" +Default observer for static quantization, usually used for debugging. +""" + +default_placeholder_observer = PlaceholderObserver +""" +Default placeholder observer, usually used for quantization to torch.float16. +""" + +default_debug_observer = RecordingObserver +""" +Default debug-only observer. +""" + +default_weight_observer = MinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_tensor_symmetric +) +""" +Default weight observer. +""" + +weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + quant_min=-127, + quant_max=127, + eps=2**-12, +) +""" +Symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +default_histogram_observer = HistogramObserver.with_args(quant_min=0, quant_max=127) +""" +Default histogram observer, usually used for PTQ. 
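`get_observer_state_dict` and `load_observer_state_dict` above make it possible to calibrate once and replay the collected statistics into a fresh copy of a model. A toy sketch standing in for a real prepared model, assuming this patch is applied; the `Wrapped` module is made up, but the attribute name must contain "activation_post_process" for the helpers to pick it up::

    import torch

    from torchao.quantization.pt2e_flow.observer import (
        MinMaxObserver,
        get_observer_state_dict,
        load_observer_state_dict,
    )

    class Wrapped(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.activation_post_process = MinMaxObserver()

        def forward(self, x):
            return self.activation_post_process(self.linear(x))

    calibrated = Wrapped()
    calibrated(torch.randn(2, 4))  # populates the observer's min_val / max_val

    obs_state = get_observer_state_dict(calibrated)  # keeps only observer buffers
    fresh = Wrapped()
    load_observer_state_dict(fresh, obs_state)  # fresh now carries the same statistics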
+""" + +default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric +) +""" +Default per-channel weight observer, usually used on backends where per-channel +weight quantization is supported, such as `fbgemm`. +""" + +per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + quant_min=-127, + quant_max=127, + eps=2**-12, +) +""" +Per-channel, symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +default_dynamic_quant_observer = PlaceholderObserver.with_args( + dtype=torch.quint8, + quant_min=0, + quant_max=255, + is_dynamic=True, +) +""" +Default observer for dynamic quantization. +""" + +default_float_qparams_observer = PerChannelMinMaxObserver.with_args( + dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 +) +""" +Default observer for a floating point zero-point. +""" + +default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args( + dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 +) +""" +Default observer for a floating point zero-point and 4 bit activations. +""" + +# TODO(future PR): remove these defaults and enforce activation functions +# to explicitly specify their output range +default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args( + scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255 +) +default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args( + scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255 +) +# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases +default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer +default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer + +""" +Default observers for fixed qparams operations. 
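The defaults above are not observer instances but factories produced by `with_args`, so each call site gets its own observer with the pre-bound arguments. A short sketch (patch applied; the weight shape is illustrative)::

    import torch

    from torchao.quantization.pt2e_flow.observer import (
        default_per_channel_weight_observer,
        default_weight_observer,
    )

    weight = torch.randn(16, 32)

    w_obs = default_weight_observer()               # fresh per-tensor symmetric qint8 observer
    wc_obs = default_per_channel_weight_observer()  # fresh per-channel symmetric qint8 observer
    w_obs(weight)
    wc_obs(weight)

    print(w_obs.calculate_qparams()[0].numel())   # 1: a single scale for the whole tensor
    print(wc_obs.calculate_qparams()[0].numel())  # 16: one scale per output channel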
+""" + +default_reuse_input_observer = ReuseInputObserver +""" +Default observer for operators like reshape that reuses the observer of input to +the operator +""" diff --git a/torchao/quantization/pt2e_flow/pt2e/_affine_quantization.py b/torchao/quantization/pt2e_flow/pt2e/_affine_quantization.py new file mode 100644 index 0000000000..a8c6ee7b98 --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/_affine_quantization.py @@ -0,0 +1,775 @@ +# copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py +# and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py +# PLESE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC +import logging +from abc import ABCMeta +from typing import Any, Union + +import torch +from torch.fx import Node + +from torchao.quantization import ( + choose_qparams_affine_with_min_max, +) +from torchao.quantization.pt2e_flow.observer import ( + AffineQuantizedObserverBase, + TorchAODType, + get_block_size, +) + +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + +logger = logging.getLogger(__name__) + +FP8_TYPES = { + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, +} +_SUB_BYTE_UINT_BOUNDS = { + torch.uint1: (0, 2**1 - 1), + torch.uint2: (0, 2**2 - 1), + torch.uint3: (0, 2**3 - 1), + torch.uint4: (0, 2**4 - 1), + torch.uint5: (0, 2**5 - 1), + torch.uint6: (0, 2**6 - 1), + torch.uint7: (0, 2**7 - 1), +} + +""" +Map from dtype to the bound value of integers +TODO: maybe can replace this with call to torch.iinfo +""" +_DTYPE_TO_QVALUE_BOUNDS: dict[Union[torch.dtype, TorchAODType], tuple[int, int]] = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int16: (-(2**15), 2**15 - 1), + torch.int32: (-(2**31), 2**31 - 1), +} +_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS) + + +def _is_float8_type(dtype: torch.dtype) -> bool: + fp8_types = { + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + } + return dtype in fp8_types + + +# TODO: decide on if we want to allow custom quant_min/quant_max here +def _get_and_check_qmin_qmax(dtype, quant_min, quant_max): + """Get quant_min and quant_max args based on dtype and also + verify that they are within the range of possible quant_min/quant_max + for dtype + """ + if dtype in FP8_TYPES: + quant_min_lower_bound, quant_max_upper_bound = ( + torch.finfo(dtype).min, + torch.finfo(dtype).max, + ) + elif dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise ValueError(f"Unsupported dtype: {dtype}") + else: + quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] + if quant_min is None: + quant_min = quant_min_lower_bound + if quant_max is None: + quant_max = quant_max_upper_bound + + assert quant_min >= quant_min_lower_bound, ( + "quant_min out of bound for dtype, " + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + ) + + assert quant_max <= quant_max_upper_bound, ( + "quant_max out of bound for dtype, " + f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + ) + return quant_min, quant_max + + +def _get_reduction_params(block_size, input_size): + """Given block_size and input size find the parameters for reduction: + + Output: + shape_for_reduction: the shape we use to `view` input to prepare it for reduction + reduction_dims: the dims we'll do reduction over + + Example:: + Input: + block_size: (3, 3, 2, 10) + input_size: (3, 3, 10, 10) + + Output: + 
shape_for_reduction: (3, 3, 5, 2, 10) + reduction_dim: [0, 1, 3, 4] + """ + assert len(block_size) == len(input_size) + shape_for_reduction = [] + reduction_dims = [] + cur_dim = 0 + for i in range(len(block_size)): + if block_size[i] != input_size[i] and block_size[i] > 1: + assert input_size[i] % block_size[i] == 0, ( + f"Expecting input size at {i} dimension: " + f"{input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}" + ) + shape_for_reduction.append(input_size[i] // block_size[i]) + shape_for_reduction.append(block_size[i]) + # reduce over the block_size[i] dim + reduction_dims.append(cur_dim + 1) + cur_dim += 2 + else: + # block_size[i] == input_size[i] or block_size[i] == 1 + shape_for_reduction.append(input_size[i]) + # we only need to reduce over the dimension if block_size is greater than 1 + # otherwise it's already the same as reduced dimension + if block_size[i] != 1: + reduction_dims.append(cur_dim) + cur_dim += 1 + return shape_for_reduction, reduction_dims + + +def _register_custom_op(lib): + """This decorator is used to preserve some high level operators for torch.export.export + while still allow them to be decomposed for inductor path + + requirement: make sure `fn.__name__[1:]` is the operator name you want to register + + NOTE: This should be applied at the top, after all other decorators have been applied + NOTE: We haven't tested the case when `fn` accepts tensor subclass instance as input, + e.g. uint4 tensor subclass instance, and we'll probably need to figure out what would make + sense for downstream system (like executorch) to accept as well + + Example: + lib = torch.library.Library("my_namespace', "FRAGMENT") + + register_custom_op = _register_custom_op(lib) + + @register_custom_op + def _the_op_that_needs_to_be_preserved(...) + ... 
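The reshape-then-reduce scheme documented in `_get_reduction_params` above is what the affine observers use to obtain one min/max per block. A small sketch reproducing the docstring's example shapes (not part of the patch)::

    import torch

    from torchao.quantization.pt2e_flow.pt2e._affine_quantization import (
        _get_reduction_params,
    )

    x = torch.randn(3, 3, 10, 10)
    block_size = (3, 3, 2, 10)  # one qparam per (3, 3, 2, 10) block

    shape_for_reduction, reduction_dims = _get_reduction_params(block_size, x.size())
    print(shape_for_reduction, reduction_dims)  # [3, 3, 5, 2, 10] [0, 1, 3, 4]

    x_blocked = x.view(shape_for_reduction)
    min_val = torch.amin(x_blocked, dim=reduction_dims)  # shape [5]: one min per block
    max_val = torch.amax(x_blocked, dim=reduction_dims)  # shape [5]: one max per block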
+ + # after this, `_the_op_that_needs_to_be_preserved` will be preserved as + # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after + # torch.export.export / torch._export.export_for_training + + """ + from torch._inductor.decomposition import register_decomposition + + def decorator(fn): + from torch._library.infer_schema import infer_schema + + # expecting fn.__name__ starts with `_` and we want to take the rest + # to be the name of the custom op + assert ( + fn.__name__[0] == "_" + ), f"Expecting function name starts with `_`, got {fn.__name__}" + assert not any( + c in fn.__name__ for c in ".<>" + ), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" + op_name = fn.__name__[1:] + schema = op_name + infer_schema(fn, mutates_args={}) + lib.define(schema) + lib.impl(op_name, fn, "CompositeImplicitAutograd") + + lib_namespace = lib.ns + op = getattr(getattr(torch.ops, lib_namespace), op_name) + register_decomposition([op])(fn) + return op + + return decorator + + +# quant_lib = torch.library.Library("pt2e_quant", "FRAGMENT") # noqa: TOR901 + +# register_custom_op = _register_custom_op(quant_lib) + + +# def choose_qparams_affine_with_min_max( +# min_val: torch.Tensor, +# max_val: torch.Tensor, +# mapping_type: MappingType, +# block_size: tuple[int, ...], +# target_dtype: torch.dtype, +# quant_min: Optional[int] = None, +# quant_max: Optional[int] = None, +# eps: Optional[float] = None, +# scale_dtype: Optional[torch.dtype] = None, +# zero_point_dtype: Optional[torch.dtype] = None, +# preserve_zero: bool = True, +# zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, +# ) -> tuple[torch.Tensor, torch.Tensor]: +# """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine` +# operator that pass in min_val and max_val directly instead of deriving these from a single input. +# This is used for observers in static quantization where min_val and max_val may be obtained through +# tracking all the data in calibration data set. + +# Args: +# Mostly same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`. with one +# difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val +# and then scale/zero_point, we pass in min_val/max_val directly +# """ +# return _choose_qparams_affine( +# None, +# mapping_type.name, +# block_size, +# target_dtype, +# quant_min, +# quant_max, +# eps, +# scale_dtype, +# zero_point_dtype, +# preserve_zero, +# zero_point_domain.name if zero_point_domain is not None else None, +# min_val, +# max_val, +# ) + + +# @register_custom_op +# def _choose_qparams_affine( +# input: Optional[torch.Tensor], +# mapping_type: str, +# block_size: List[int], +# target_dtype: torch.dtype, +# quant_min: Optional[Union[int, float, bool]] = None, +# quant_max: Optional[Union[int, float, bool]] = None, +# eps: Optional[float] = None, +# scale_dtype: Optional[torch.dtype] = None, +# zero_point_dtype: Optional[torch.dtype] = None, +# preserve_zero: bool = True, +# zero_point_domain: Optional[str] = "INT", +# min_val: Optional[torch.Tensor] = None, +# max_val: Optional[torch.Tensor] = None, +# ) -> tuple[torch.Tensor, torch.Tensor]: +# """op definition that has compatible signatures with custom op library + +# The op does the following: +# 1. figure out the dimension for reduction based on block_size +# 2. find min_val/max_val based on the dimension for reduction +# 3. 
calculate quantization parameters based on min_val/max_val based on args like `preserve_zero` +# and `zero_point_domain` +# """ +# quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) +# assert mapping_type in [ +# MappingType.SYMMETRIC.name, +# MappingType.SYMMETRIC_NO_CLIPPING_ERR.name, +# MappingType.ASYMMETRIC.name, +# ], f"Unsupported mapping type: {mapping_type}" +# if target_dtype in FP8_TYPES: +# assert ( +# mapping_type == MappingType.SYMMETRIC.name +# ), f"Only symmetric quantization is supported for FP8 types, got {mapping_type}" + +# if input is not None: +# if scale_dtype is None: +# scale_dtype = input.dtype +# if zero_point_dtype is None: +# zero_point_dtype = input.dtype +# if eps is None: +# eps = torch.finfo(input.dtype).eps + +# assert ( +# len(block_size) == input.dim() +# ), f"Got input dim:{input.dim()}, block_size: {block_size}" +# shape_for_reduction, reduction_dims = _get_reduction_params( +# block_size, input.size() +# ) +# input = input.view(shape_for_reduction) + +# min_val = torch.amin(input, dim=reduction_dims, keepdim=False) +# max_val = torch.amax(input, dim=reduction_dims, keepdim=False) +# else: +# assert ( +# min_val is not None and max_val is not None +# ), "Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}" +# assert ( +# min_val.dtype == max_val.dtype +# ), "Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}" + +# if scale_dtype is None: +# scale_dtype = min_val.dtype +# if zero_point_dtype is None: +# zero_point_dtype = min_val.dtype +# if eps is None: +# eps = torch.finfo(min_val.dtype).eps + +# if preserve_zero: +# min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) +# max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) +# else: +# min_val_neg = min_val +# max_val_pos = max_val + +# if ( +# mapping_type == MappingType.SYMMETRIC.name +# or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name +# ): +# # scales +# if mapping_type == MappingType.SYMMETRIC.name: +# max_val_pos = torch.max(-min_val_neg, max_val_pos) +# scale = max_val_pos / (float(quant_max - quant_min) / 2) +# else: +# assert mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name +# # calculate smin and smax individually and choose the larger one. For example, if quant_min = -8 and +# # quant_max = 7. +# # - If smin is bigger: There would be coverage on negative values down to -8, and less rounding +# # error than the existing SYMMETRIC case. +# # - If smax is bigger: it covers the positive values up to 7. The round +# # error may be bigger than the existing SYMMETRIC case. Either way, there's no out-of-range fp values after +# # quantization. 
+# smin = min_val_neg / float(quant_min) +# smax = max_val_pos / float(quant_max) +# mask = smin > smax +# scale = torch.where(mask, smin, smax) +# # zeros +# if not preserve_zero: +# raise ValueError( +# "preserve_zero == False is not supported for symmetric quantization" +# ) +# if ( +# zero_point_domain is not None +# and zero_point_domain != ZeroPointDomain.INT.name +# ): +# raise ValueError( +# "zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization" +# ) +# scale = torch.clamp(scale, min=eps) +# zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) +# else: +# assert mapping_type == MappingType.ASYMMETRIC.name +# scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) +# scale = torch.clamp(scale, min=eps) +# if zero_point_domain == ZeroPointDomain.NONE.name: +# zero_point = None +# else: +# if preserve_zero: +# zero_point = quant_min - torch.round(min_val_neg / scale) +# zero_point = torch.clamp(zero_point, quant_min, quant_max) +# else: +# assert ( +# zero_point_domain == ZeroPointDomain.FLOAT.name +# ), "if not preserve_zero, zero_point must be in FLOAT domain" +# mid_point = (quant_max + quant_min + 1) / 2 +# zero_point = min_val_neg + scale * mid_point + +# if zero_point is not None: +# zero_point = zero_point.to(dtype=zero_point_dtype) +# return scale.to(dtype=scale_dtype), zero_point + + +# @torch.no_grad() +# def quantize_affine( +# input: torch.Tensor, +# block_size: tuple[int, ...], +# scale: torch.Tensor, +# zero_point: Optional[torch.Tensor], +# output_dtype: torch.dtype, +# quant_min: Optional[Union[int, float]] = None, +# quant_max: Optional[Union[int, float]] = None, +# zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, +# ) -> torch.Tensor: +# """ +# Args: +# input (torch.Tensor): original float32, float16 or bfloat16 Tensor +# block_size: (Tuple[int, ...]): granularity of quantization, +# this means the size of the tensor elements that's sharing the same qparam +# e.g. when size is the same as the input tensor dimension, we are using per tensor quantization +# scale (float): quantization parameter for affine quantization +# zero_point (int): quantization parameter for affine quantization +# output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor +# quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype +# quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype +# zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float +# if zero_point is in integer domain, zero point is added to the quantized integer value during +# quantization +# if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) +# value during quantization +# default is ZeroPointDomain.INT + +# Note: +# How can block_size represent different granularities? 
+# let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different +# granularities: + +# granularity type | block_size +# per_tensor | (3, 3, 10, 10) +# per_axis (axis=0) | (1, 3, 10, 10) +# per_axis (axis=1) | (3, 1, 10, 10) +# per_group (groupsize=2) | (3, 3, 10, 2) +# per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10) + + +# Output: +# quantized tensor with requested dtype +# """ +# return _quantize_affine( +# input, +# block_size, +# scale, +# zero_point, +# output_dtype, +# quant_min, +# quant_max, +# zero_point_domain.name if zero_point_domain is not None else None, +# ) + + +# @register_custom_op +# def _quantize_affine( +# input: torch.Tensor, +# block_size: List[int], +# scale: torch.Tensor, +# zero_point: Optional[torch.Tensor], +# output_dtype: torch.dtype, +# quant_min: Optional[Union[int, float, bool]] = None, +# quant_max: Optional[Union[int, float, bool]] = None, +# zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, +# ) -> torch.Tensor: +# """op definition that has compatible signatures with custom op library + +# Note: +# zero_point_domain is optional specifies how we quantize the floating point to quantized data: +# INT: quantized_val = (float_val / scale) (integer) + zero_point (integer) +# FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale +# None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization +# Where we do not want to round values to nearest integer and instead scale and cast. +# """ +# quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) +# # workaround for uintx dtypes, since we don't have native Uintx dtype connected with +# # torch.uintx dtypes yet +# if output_dtype in _SUB_BYTE_UINT_BOUNDS: +# output_dtype = torch.uint8 +# return _quantize_affine_no_dtype_cast( +# input, +# block_size, +# scale, +# zero_point, +# quant_min, +# quant_max, +# zero_point_domain, +# ).to(output_dtype) + + +# def _quantize_affine_no_dtype_cast( +# input: torch.Tensor, +# block_size: list[int], +# scale: torch.Tensor, +# zero_point: Optional[torch.Tensor], +# quant_min: Union[int, float], +# quant_max: Union[int, float], +# zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, +# ) -> torch.Tensor: +# """ +# The op does the following: +# 1. figure out the dimension for reduction based on block_size, also reshape the input to align with +# the shape after reduction +# 2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain +# 3. 
reshape the quantized result to origianl shape +# """ +# # TODO: validations +# # TODO: validate scale/zero_point dimensions are compatible with block_size +# assert input.dtype in [ +# torch.float32, +# torch.float16, +# torch.bfloat16, +# ], f"Unsupported input dtype: {input.dtype}" +# assert ( +# len(block_size) == input.dim() +# ), f"Got input dim:{input.dim()}, block_size: {block_size}" +# shape_for_reduction, reduction_dims = _get_reduction_params( +# block_size, input.size() +# ) +# original_shape = input.shape +# input = input.view(shape_for_reduction) +# shape_after_reduction = shape_for_reduction +# for i in reduction_dims: +# shape_after_reduction[i] = 1 +# scale = scale.view(shape_after_reduction) +# if zero_point is not None: +# zero_point = zero_point.view(shape_after_reduction) + +# if zero_point_domain == ZeroPointDomain.INT.name: +# quant = torch.clamp( +# torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max +# ) +# elif zero_point_domain == ZeroPointDomain.NONE.name: +# assert ( +# zero_point is None +# ), "zero_point should be None when zero_point_domain is NONE" +# quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max) +# elif zero_point_domain is None: +# # This case handles quantization for float8 we expect no zero point and no zero point domain +# assert ( +# zero_point is None +# ), "zero_point should be None when zero_point_domain is None" +# quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) +# else: +# assert zero_point_domain == ZeroPointDomain.FLOAT.name +# mid_point = (quant_max + quant_min + 1) / 2 +# min_val = zero_point - scale * mid_point +# quant = torch.clamp( +# torch.round((input - min_val) / scale), quant_min, quant_max +# ) +# quant = quant.view(original_shape) + +# return quant + + +# def dequantize_affine( +# input: torch.Tensor, +# block_size: tuple[int, ...], +# scale: torch.Tensor, +# zero_point: Optional[torch.Tensor], +# input_dtype: torch.dtype, +# quant_min: Optional[Union[int, float]] = None, +# quant_max: Optional[Union[int, float]] = None, +# zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, +# *, +# output_dtype: torch.dtype = torch.float32, +# ) -> torch.Tensor: +# """ +# Args: +# input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument +# block_size: (List[int]): granularity of quantization, +# this means the size of the tensor elements that's sharing the same qparam +# e.g. when size is the same as the input tensor dimension, we are using per tensor quantization +# scale (Tensor): quantization parameter for affine quantization +# zero_point (Tensor): quantization parameter for affine quantization +# input_dtype (torch.dtype): requested dtype (e.g. 
torch.uint8) for output Tensor +# quant_min (Optional[int]): minimum quantized value for input Tensor +# quant_max (Optional[int]): maximum quantized value for input Tensor +# output_dtype (torch.dtype): dtype for output Tensor, default is fp32 +# zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float +# if zero_point is in integer domain, zero point is added to the quantized integer value during +# quantization +# if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) +# value during quantization +# default is ZeroPointDomain.INT + +# Output: +# dequantized Tensor, with requested dtype or fp32 +# """ +# return _dequantize_affine( +# input, +# block_size, +# scale, +# zero_point, +# input_dtype, +# quant_min, +# quant_max, +# zero_point_domain.name if zero_point_domain is not None else None, +# output_dtype=output_dtype, +# ) + + +# @register_custom_op +# def _dequantize_affine( +# input: torch.Tensor, +# block_size: List[int], +# scale: torch.Tensor, +# zero_point: Optional[torch.Tensor], +# input_dtype: torch.dtype, +# quant_min: Optional[Union[int, float, bool]] = None, +# quant_max: Optional[Union[int, float, bool]] = None, +# zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, +# output_dtype: torch.dtype = torch.float32, +# ) -> torch.Tensor: +# """op definition that has compatible signatures with custom op library""" +# # TODO: validate scale/zero_point dimensions are compatible with block_size +# if input_dtype not in _SUB_BYTE_UINT_BOUNDS: +# assert ( +# input.dtype == input_dtype +# ), f"Expected: {input_dtype}, got: {input.dtype}" +# assert output_dtype in [ +# torch.float32, +# torch.float16, +# torch.bfloat16, +# ], f"Unsupported output dtype: {output_dtype}" +# quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) +# return _dequantize_affine_no_dtype_check( +# input, +# block_size, +# scale, +# zero_point, +# quant_min, +# quant_max, +# zero_point_domain, +# output_dtype, +# ) + + +# def _dequantize_affine_no_dtype_check( +# input: torch.Tensor, +# block_size: list[int], +# scale: torch.Tensor, +# zero_point: Optional[torch.Tensor], +# quant_min: Union[int, float], +# quant_max: Union[int, float], +# zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, +# output_dtype: torch.dtype = torch.float32, +# ) -> torch.Tensor: +# """This function converts AQT tensors to their high precision floating point representation + +# The op does the following: +# 1. figure out the dimension for reduction based on block_size, also reshape the input to align with +# the shape after reduction +# 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain +# 3. 
reshape the quantized result to origianl shape and change dtype to the output_dtype +# """ +# assert ( +# len(block_size) == input.dim() +# ), f"Got input dim:{input.dim()}, block_size: {block_size}" +# shape_for_reduction, reduction_dims = _get_reduction_params( +# block_size, input.size() +# ) +# original_shape = input.shape +# input = input.view(shape_for_reduction) +# shape_after_reduction = shape_for_reduction +# for i in reduction_dims: +# shape_after_reduction[i] = 1 +# scale = scale.view(shape_after_reduction) + +# if zero_point is not None: +# zero_point = zero_point.view(shape_after_reduction) + +# if zero_point_domain == ZeroPointDomain.INT.name: +# # Force a copy to avoid input modification due +# # to upcoming in-place operations. +# dequant = input.to(torch.int32, copy=True) +# if zero_point is not None: +# dequant = dequant - zero_point.to(torch.int32) +# dequant = dequant.to(output_dtype) +# dequant = dequant * scale +# elif zero_point_domain == ZeroPointDomain.NONE.name: +# assert ( +# zero_point is None +# ), "zero_point should be None when zero_point_domain is NONE" +# dequant = input.to(output_dtype) +# dequant = dequant * scale +# elif zero_point_domain is None: +# # This case handles dequantization for float8 we expect no zero point and no zero point domain +# assert ( +# zero_point is None +# ), "zero_point should be None when zero_point_domain is None" +# assert _is_float8_type( +# input.dtype +# ), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" +# dequant = input.to(output_dtype) +# dequant = dequant * scale +# else: +# assert ( +# zero_point_domain == ZeroPointDomain.FLOAT.name +# ), f"Unexpected zero point domain: {zero_point_domain}" +# # TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this) +# mid_point = (quant_max + quant_min + 1) / 2 +# # This should allocate new memory and avoid input modification +# dequant = input - mid_point +# dequant = dequant.to(output_dtype) +# dequant *= scale +# if zero_point is not None: +# dequant += zero_point + +# return dequant.view(original_shape).to(output_dtype) + + +class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase): + def forward(self, input: torch.Tensor): + if input.numel() == 0: + return input + + input_detached = input.detach() + self.original_dtype = input_detached.dtype + assert self.granularity is not None, "granularity is None" + self.block_size = get_block_size(input_detached.shape, self.granularity) + + shape_for_reduction, reduction_dims = _get_reduction_params( + self.block_size, input_detached.size() + ) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + assert ( + self.min_val.shape == min_val.shape + ), f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + assert ( + self.max_val.shape == max_val.shape + ), f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + min_val = torch.min(self.min_val, min_val) + max_val = torch.max(self.max_val, max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + # returning original input + return input + + def calculate_qparams(self) -> 
tuple[torch.Tensor, torch.Tensor]: + assert ( + hasattr(self, "min_val") and hasattr(self, "max_val") + ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + [], # BlockSize is not needed because the min/max are already reduced + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + def convert(self, model: torch.fx.GraphModule, observer_node: Node): + from torch.ao.quantization.fx.utils import create_getattr_from_value + + scale, zero_point = self.calculate_qparams() + with model.graph.inserting_before(observer_node): + assert self.block_size is not None, "Expecting block_size to be populated" + assert ( + self.original_dtype is not None + ), "Expecting original_dtype to be populated" + scale_node = create_getattr_from_value(model, model.graph, "_scale", scale) + zero_point_node = create_getattr_from_value( + model, model.graph, "_zero_point", zero_point + ) + q_node = model.graph.call_function( + torch.ops.torchao_quant.quantize_affine, + ( + observer_node.args[0], + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {}, + ) + dq_node = model.graph.call_function( + torch.ops.torchao_quant.dequantize_affine, + ( + q_node, + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {"output_dtype": self.original_dtype}, + ) + observer_node.replace_all_uses_with(dq_node) + model.graph.erase_node(observer_node) diff --git a/torchao/quantization/pt2e_flow/pt2e/_numeric_debugger.py b/torchao/quantization/pt2e_flow/pt2e/_numeric_debugger.py new file mode 100644 index 0000000000..3e2badb452 --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/_numeric_debugger.py @@ -0,0 +1,342 @@ +import copy +import logging +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +from torch.ao.ns.fx.utils import compute_sqnr +from torch.export import ExportedProgram +from torch.fx import GraphModule, Node +from torch.nn import functional as F + +from torchao.quantization.pt2e_flow.pt2e.graph_utils import bfs_trace_with_node_process + +NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle" +CUSTOM_KEY = "custom" + +log = logging.getLogger(__name__) + + +def generate_numeric_debug_handle(ep: ExportedProgram) -> None: + """ + Attach numeric_debug_handle_id for all nodes in the graph module of the given + ExportedProgram, like conv2d, squeeze, conv1d, etc, except for placeholder. + Notice that nodes like getattr are out of scope since they are not in the graph. + + The graph nodes of input exported program are modified inplace. 
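+    Handles are stored under node.meta["custom"]["numeric_debug_handle"].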
+
+    Here's an example of using the debug handle in the quantize flow::
+
+        ep = export_for_training(eager_model, example_inputs)
+        generate_numeric_debug_handle(ep)
+
+        m = ep.module()
+        quantizer = XNNPACKQuantizer()
+        m = prepare_pt2e(m, quantizer)
+        m = convert_pt2e(m)
+    """
+
+    # Sanity check the input data type
+    if not isinstance(ep, ExportedProgram):
+        raise ValueError(
+            f"Expected ep to be ExportedProgram, got {type(ep)}"
+        )
+
+    unique_id = 0
+
+    def _find_max_id(node: torch.fx.Node) -> None:
+        nonlocal unique_id
+        unique_id = max(
+            unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
+        )
+
+    def _assign_debug_handle(node: torch.fx.Node) -> None:
+        nonlocal unique_id
+        if CUSTOM_KEY not in node.meta:
+            node.meta[CUSTOM_KEY] = {}
+
+        if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]:
+            node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
+            unique_id += 1
+
+    # Find the max ID that exists in the graph first, in case part of the graph
+    # has already been annotated. This way we guarantee there are no duplicate
+    # handle IDs.
+    bfs_trace_with_node_process(ep, _find_max_id)
+
+    unique_id += 1
+
+    # Assign debug handles to all nodes in the graph that don't have one, based on
+    # the max ID found in the previous step.
+    bfs_trace_with_node_process(ep, _assign_debug_handle)
+
+
+def _detach(x: object) -> object:
+    detached: object = None
+    if isinstance(x, torch.Tensor):
+        detached = x.detach()
+    elif isinstance(x, (list, tuple)):
+        detached = type(x)([_detach(e) for e in x])
+    elif isinstance(x, dict):
+        detached = {k: _detach(e) for k, e in x.items()}
+    else:
+        detached = x
+    return detached
+
+
+def _tensor_shape_equals(x: object, y: object) -> bool:
+    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+        return x.shape == y.shape
+    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
+        return all(_tensor_shape_equals(e1, e2) for e1, e2 in zip(x, y))
+    elif isinstance(x, dict) and isinstance(y, dict):
+        all_equal = True
+        for k in x:
+            all_equal = all_equal and k in y and (_tensor_shape_equals(x[k], y[k]))
+        return all_equal
+    else:
+        log.debug("Comparing non-Tensors: %s and %s, they must be equal", x, y)
+        return type(x) == type(y) and x == y
+
+
+def _loss_fn(
+    loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: object, y: object
+) -> object:
+    """The returned loss will have the same structure as `x` and `y`, e.g.
+    if both are Tensors, we'll return a Tensor
+    if both are lists, we'll return a list of Tensors
+    if both are dicts, we'll return a dict with the same keys, with each value being
+    the loss between the two corresponding Tensors
+    """
+    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+        return loss(x.to(torch.float32), y.to(torch.float32))
+    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
+        return type(x)([_loss_fn(loss, e1, e2) for e1, e2 in zip(x, y)])
+    elif isinstance(x, dict) and isinstance(y, dict):
+        return {k: _loss_fn(loss, e, y[k]) for k, e in x.items()}
+    else:
+        return None
+
+
+class OutputLogger(torch.nn.Module):
+    """
+    Base class for capturing output values for nodes in a GraphModule. It only captures
+    Tensor output currently, but we can extend it to work for other types of inputs later if needed.
+    """
+
+    # Mark as impure so that calls to it will not be removed during DCE.
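+    # torch.fx dead code elimination (graph.eliminate_dead_code()) keeps nodes whose
+    # call target is marked impure, so the logger call_module nodes survive cleanup
+    # even though their outputs are not consumed elsewhere in the graph.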
+ _is_impure = True + + def __init__( + self, + debug_handle: int, + node_name: Optional[str] = None, + nn_module_stack: Optional[object] = None, + ) -> None: + super().__init__() + self.node_name = node_name + self.nn_module_stack = nn_module_stack + self.debug_handle = debug_handle + self.stats: list[object] = [] + + def forward(self, x: object) -> object: + self.stats.append(_detach(x)) + return x + + def __extra_repr__(self) -> str: + return ( + f"debug_handle={self.debug_handle}, node_name={self.node_name}, " + "nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)})" + ) + + +def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node: + """For a given node, adds an OutputLogger that observes the output of that node, + and all its users use the OutputLogger output instead. + The OutputLogger will contain the debug_handle which can be used to compare + graphs after transforms""" + + # to avoid circular dep + from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix + + # add a logger after the node + with model.graph.inserting_after(node): + get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger") + logger_name = get_new_attr_name(model) + setattr( + model, + logger_name, + OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")), + ) + logger_node = model.graph.call_module(logger_name, (node,), {}) + + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is logger_node: + continue + user_node.replace_input_with(node, logger_node) + + return logger_node + + +def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule: + """Add output loggers to node that has numeric_debug_handle + + Args: + model (GraphModule): original model + Returns: + a model with output loggers for all nodes that has numeric_debug_handle_id + """ + # don't change the original model + model = copy.deepcopy(model) + for n in model.graph.nodes: + if ( + CUSTOM_KEY not in n.meta + or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY] + ): + continue + numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] + _insert_logger(model, n, numeric_debug_handle) + + model.recompile() + return model + + +@dataclass(frozen=True) +class QuantizationComparisonResult: + actual: torch.Tensor + ref: torch.Tensor + + @property + def mse_loss(self) -> object: + return self.loss(F.mse_loss) + + @property + def sqnr(self) -> object: + return self.loss(compute_sqnr) + + def loss( + self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + ) -> object: + return _loss_fn(loss_function, self.actual, self.ref) + + def __repr__(self) -> str: + # Don't include the tensors themselves as they are quite large to print + # out. 
+ return ( + f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})" + ) + + def __post_init__(self) -> None: + if not isinstance(self.actual, (torch.Tensor, list, tuple, dict)): + raise ValueError( + f"`self.actual` value must be a Tensor, list, tuple or dict, got: {self.actual}" + ) + + if not isinstance(self.ref, (torch.Tensor, list, tuple, dict)): + raise ValueError( + f"`self.ref` value must be a Tensor, list, tuple or dict, got: {self.ref}" + ) + + if not _tensor_shape_equals(self.ref, self.actual): + raise ValueError( + f"Cannot compare tensors with different shapes: ref={self.ref} vs actual={self.actual}" + ) + + +@dataclass(frozen=True) +class NodeAccuracySummary: + handle: int + actual_node_name: str + actual_module_stack: str + ref_node_name: str + ref_module_stack: str + results: Sequence[QuantizationComparisonResult] + + +def _module_stack_to_str(module_stack: object) -> str: + """Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear") + to "mod.foo.0.linear" + """ + if not isinstance(module_stack, dict): + return str(module_stack) + module_values_list = list(module_stack.values()) + if len(module_values_list) > 0: + owning_module = module_values_list[-1][0] + return str(owning_module) + else: + return str(module_stack) + + +def extract_results_from_loggers( + model: GraphModule, +) -> dict[int, tuple[Optional[str], object, list[object]]]: + """For a given model, extract the tensors stats and related information for each debug handle. + The reason we have a list of object, instead of Tensor is because the output of node may not be + a Tensor, it could be (nested) list, tuple or dict as well. + + Returns: + A dict is keyed by the debug_handle id and the values are a list of object recorded + in loggers + + """ + # Results maps debug handle to a tensor list for each model being compared. + handles: dict[int, tuple[Optional[str], object, list[object]]] = {} + for _name, module in model.named_children(): + if isinstance(module, OutputLogger) and len(module.stats) > 0: + handles[module.debug_handle] = ( + module.node_name, + module.nn_module_stack, + module.stats, + ) + + return handles + + +def compare_results( + ref_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]], + actual_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]], +) -> dict[int, NodeAccuracySummary]: + """Given two dict mapping from `debug_handle_id` (int) to list of tensors + return a map from `debug_handle_id` to `NodeAccuracySummary` that contains + comparison information like SQNR, MSE etc. + + Args: + ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id + actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id + + Returns: + Dict[int, NodeAccuracySummary] + """ + comparisons = {} + for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items(): + if debug_handle not in actual_results: + log.debug( + "Cannot compare for handle %s because it wasn't found in the transformed model", + debug_handle, + ) + continue + actual_name, actual_stack, actual_stats = actual_results[debug_handle] + try: + results = [ + QuantizationComparisonResult(actual=a, ref=b) + for a, b in zip(actual_stats, ref_stats) + ] + except Exception as e: + # Add extra information for an exception from QuantizationComparisonResult + # if the shapes didn't match, to include the handle and the node names. 
+ raise ValueError( + f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}" + ) from e + + comparisons[debug_handle] = NodeAccuracySummary( + handle=debug_handle, + actual_node_name=actual_name or "", + actual_module_stack=_module_stack_to_str(actual_stack), + ref_node_name=ref_name or "", + ref_module_stack=_module_stack_to_str(ref_stack), + results=results, + ) + + return comparisons diff --git a/torchao/quantization/pt2e_flow/pt2e/convert.py b/torchao/quantization/pt2e_flow/pt2e/convert.py new file mode 100644 index 0000000000..a781d57d81 --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/convert.py @@ -0,0 +1,1400 @@ +# mypy: ignore-errors + +import copy +import operator +import warnings +from typing import Any, Callable, Optional, Union + +import torch +from torch.ao.quantization.backend_config import ( + BackendConfig, + get_native_backend_config, +) +from torch.ao.quantization.backend_config.utils import ( + get_fused_module_classes, + get_pattern_to_dtype_configs, + get_qat_module_classes, + get_root_module_to_quantized_reference_module, +) + +# importing the lib so that the quantized_decomposed ops are registered +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.ao.quantization.fx._equalize import ( + convert_eq_obs, + update_obs_for_equalization, +) +from torch.ao.quantization.fx.custom_config import ( + ConvertCustomConfig, + PrepareCustomConfig, +) +from torch.ao.quantization.fx.graph_module import ( + _is_observed_module, + _is_observed_standalone_module, +) +from torch.ao.quantization.fx.lower_to_fbgemm import lower_to_fbgemm +from torch.ao.quantization.fx.qconfig_mapping_utils import ( + _compare_prepare_convert_qconfig_mappings, + _generate_node_name_to_qconfig, + _is_qconfig_supported_by_dtype_configs, + _update_qconfig_for_fusion, + _update_qconfig_for_qat, +) +from torch.ao.quantization.fx.utils import ( + _get_module, + _is_custom_module_lstm, + _is_custom_module_mha, + assert_and_get_unique_device, + collect_producer_nodes, + create_getattr_from_value, + get_custom_module_class_keys, + graph_module_from_producer_nodes, + node_arg_is_weight, +) +from torch.ao.quantization.qconfig import QConfigAny, qconfig_equals +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quant_type import QuantType +from torch.ao.quantization.quantize import _remove_qconfig +from torch.ao.quantization.stubs import DeQuantStub +from torch.ao.quantization.utils import ( + _parent_name, + activation_is_statically_quantized, + get_qparam_dict, + get_swapped_custom_module_class, + is_per_channel, + to_underlying_dtype, + weight_is_quantized, +) +from torch.fx import GraphModule +from torch.fx.graph import Argument, Graph, Node +from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY +from torch.nn.utils.parametrize import type_before_parametrizations + +from torchao.quantization.pt2e_flow import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY +from torchao.quantization.pt2e_flow.observer import _is_activation_post_process + +__all__ = [ + "convert", + "convert_custom_module", + "convert_standalone_module", + "convert_weighted_module", +] + +SUPPORTED_QDTYPES = [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.uint8, + torch.int8, + torch.uint16, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, +] + +_QSCHEME_TO_CHOOSE_QPARAMS_OP = { + torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor, + 
torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, +} + + +def attach_preserved_attrs_to_model( + model: Union[GraphModule, torch.nn.Module], + preserved_attrs: dict[str, Any], +) -> None: + """Store preserved attributes to the model.meta so that it can be preserved during deepcopy""" + model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs) # type: ignore[operator, index, assignment] + # set the preserved attributes in the model so that user can call + # model.attr as they do before calling fx graph mode quantization + for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): # type: ignore[index, union-attr] + setattr(model, attr_name, attr) + + +def _check_is_graph_module(model: torch.nn.Module) -> None: + if not isinstance(model, GraphModule): + raise ValueError( + "input model must be a GraphModule, " + + "Got type:" + + str(type(model)) + + " Please make " + + "sure to follow the tutorials." + ) + + +def _replace_observer_with_quantize_dequantize_node_decomposed( + model: torch.fx.GraphModule, + node: Node, + modules: dict[str, torch.nn.Module], + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], +) -> None: + """Replace activation_post_process module call node with quantize and + dequantize node working with decomposed Tensor + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) -> + torch.ops.quantized_decomposed.dequantize_per_tensor() -> ... + + or quantize_per_channel and dequantize_per_channel + """ + graph = model.graph + assert modules is not None + assert isinstance(node.target, str) + module_path, prefix = _get_module_path_and_prefix( + node, node_name_to_scope, node_name_to_qconfig + ) + activation_post_process = modules[node.target] + if hasattr(activation_post_process, "convert"): + activation_post_process.convert(model, node) + return + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all( + _has_none_qconfig(n, node_name_to_qconfig) + for n in list(node.args) + list(node.users.keys()) + ) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find corresponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + + # 1. extract the information from activation_post_process module for generating + # the quantize and dequantize operator + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment] + + def add_dequantize_op_kwargs(dequantize_op, input_node): + dequantize_op_kwargs = {} + if "val" in input_node.meta: + dq_out_dtype = input_node.meta["val"].dtype + if dq_out_dtype != torch.float32: + dequantize_op_kwargs = {"out_dtype": dq_out_dtype} + return dequantize_op_kwargs + + if dtype in SUPPORTED_QDTYPES and (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. 
extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op: Optional[Callable] = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default + dequantize_op = ( + torch.ops.quantized_decomposed.dequantize_per_channel.default + ) + quant_min = activation_post_process.quant_min + quant_max = activation_post_process.quant_max + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_axis_": ch_axis, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_, + } + else: + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + scale = float(scale) + zero_point = int(zero_point) + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_, + } + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"] and ( + not isinstance(value_or_node, (float, int)) + ): + # For scale and zero_point values we register them as buffers in the root module. + # However, note that when the values are not tensors, as in the case of + # per_tensor quantization, they will be treated as literals. + # However, registering them as a node seems to cause issue with dynamo + # tracing where it may consider tensor overload as opposed to default. + # With extra check of scale and zero_point being scalar, it makes + # sure that the default overload can be used. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, graph, module_path + prefix + key, value_or_node + ) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. 
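+                    # (illustrative) for per-tensor int8 the resulting call looks roughly like
+                    #   quantize_per_tensor.default(x, 0.015, 0, -128, 127, torch.int8)
+                    # i.e. scalar scale/zero_point and the remaining qparams are inlined as
+                    # literal arguments, while per-channel scale/zero_point tensors become
+                    # get_attr nodes via the branch above.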
+ quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + add_dequantize_op_kwargs(dequantize_op, input_node), + ) + + node.replace_all_uses_with(dequantized_node) + # propagate numeric debug handle from observer/fake_quant node to dequantize node + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + if CUSTOM_KEY not in dequantized_node.meta: + dequantized_node.meta[CUSTOM_KEY] = {} + dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] + graph.erase_node(node) + elif is_dynamic: + # uint8/int8/fp16 dynamic quantization + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor + # we only use choose_qparams for is_decomposed now, + # but we should probably align the non-decomposed path with this as well, + # and that can be done after we remove reduce_range flag + # 1. extract qparams from activation_post_process module + dtype_ = to_underlying_dtype(dtype) + assert dtype_ in [torch.uint8, torch.int8], ( + "only uint8 and int8 are supported in reference flow for " + "dynamic quantization right now" + ) + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined] + eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps) # type: ignore[attr-defined] + # note: scale and zero_point are missing for quantize_per_tensor op + # we'll need to get this from choose_qparams op, which we'll add after + # this step + qparams = { + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_eps_": eps, + "_dtype_": dtype_, + } + + choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme] + # 2. insert choose_qparams op and update the qparams list + with graph.inserting_before(node): + input_node = node.args[0] + choose_qparams_op_inputs = [node.args[0]] + for key, value in qparams.items(): + # we have quant_min, quant_max and dtype, all should be stored + # as literals + choose_qparams_op_inputs.append(value) + choose_qparams_node = graph.create_node( + "call_function", choose_qparams_op, tuple(choose_qparams_op_inputs), {} + ) + # choose_qparms returns (scale, zero_point) + scale_node = graph.create_node( + "call_function", operator.getitem, (choose_qparams_node, 0), {} + ) + zero_point_node = graph.create_node( + "call_function", operator.getitem, (choose_qparams_node, 1), {} + ) + quant_min = qparams["_quant_min_"] + quant_max = qparams["_quant_max_"] + dtype = qparams["_dtype_"] + qparams = { + "_scale_": scale_node, + "_zero_point_": zero_point_node, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype, + } + + # 3. 
replace activation_post_process node to quantize and dequantize node + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"]: + # in this case we have a node in the graph since it's dynamically + # computed from the input, with choose_qparams op + qparam_node = value_or_node + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we + # store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + # need to use the tensor variant of this op, since scale and zero_point + # from choose_qparam are Tensors, instead of float/int, this is to + # prevent these nodes being traced away by downstream systems + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + add_dequantize_op_kwargs(dequantize_op, input_node), + ) + + node.replace_all_uses_with(dequantized_node) + # propagate numeric debug handle from observer/fake_quant node to dequantize node + if NUMERIC_DEBUG_HANDLE_KEY in node.meta: + dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + NUMERIC_DEBUG_HANDLE_KEY + ] + graph.erase_node(node) + elif dtype == torch.float16: + # Insert to_fp16 -> to_fp32 node + dtype_convert_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse + with graph.inserting_before(node): + input_node = node.args[0] + convert_fp16_node = graph.create_node( + "call_function", dtype_convert_op, (input_node, torch.float16), {} + ) + convert_fp32_node = graph.create_node( + "call_function", dtype_convert_op, (convert_fp16_node, torch.float), {} + ) + node.replace_all_uses_with(convert_fp32_node) + graph.erase_node(node) + + # should not reach since we have checks in the beginning to make sure the + # activation_post_process is supported + + +def _replace_observer_with_quantize_dequantize_node( + model: torch.fx.GraphModule, + node: Node, + modules: dict[str, torch.nn.Module], + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], +) -> None: + """Replace activation_post_process module call node with quantize and + dequantize node + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... 
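+    (or torch.quantize_per_channel(...) for per-channel qschemes; dynamic and fp16
+    observers are lowered to torch.quantize_per_tensor_dynamic and a to(torch.float16)
+    call respectively)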
+ """ + assert modules is not None + assert isinstance(node.target, str) + graph = model.graph + module_path, prefix = _get_module_path_and_prefix( + node, node_name_to_scope, node_name_to_qconfig + ) + activation_post_process = modules[node.target] + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all( + _has_none_qconfig(n, node_name_to_qconfig) + for n in list(node.args) + list(node.users.keys()) + ) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find corresponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + if dtype in [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] and (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. extract the information from activation_post_process module for generating + # the quantize and dequantize operator + node_type = "call_function" + quantize_op: Optional[Callable] = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_axis_": ch_axis, + "_dtype_": dtype, + } + quantize_op = torch.quantize_per_channel + else: + scale = float(scale) + zero_point = int(zero_point) + qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype} + quantize_op = torch.quantize_per_tensor + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"]: + # For scale and zero_point values we register them as buffers in the root module. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, graph, module_path + prefix + key, value_or_node + ) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. 
+ quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif is_dynamic: + # uint8/int8/fp16 dynamic quantization branch + + node_type = "call_function" + quantize_op = torch.quantize_per_tensor_dynamic + # TODO: get reduce range from observer + # reduce_range = activation_post_process.reduce_range + reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86") + qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range} + + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + quantize_op_inputs.append(value) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif dtype == torch.float16: + node_type = "call_method" + quantize_op = "to" # type: ignore[assignment] + qparams = {"_dtype_": dtype} + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + quantize_op_inputs.append(value) + + quantized_node = graph.create_node( + node_type, quantize_op, tuple(quantize_op_inputs), {} + ) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + + # should not reach since we have checks in the beginning to make sure the + # activation_post_process is supported + + +# this is a temporary hack for custom module, we may want to implement +# this properly after the custom module class design is finalized +# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted +# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs +# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively. +def _replace_observer_or_dequant_stub_with_dequantize_node( + node: Node, graph: Graph +) -> None: + call_custom_module_node = node.args[0] + assert isinstance( + call_custom_module_node, Node + ), f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" + node.replace_all_uses_with(call_custom_module_node) + graph.erase_node(node) + _insert_dequantize_node(call_custom_module_node, graph) + + +def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool: + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + return ( + (dtype in SUPPORTED_QDTYPES and (not is_dynamic)) + or is_dynamic # type: ignore[return-value] + or dtype == torch.float16 + ) + + +def _has_none_qconfig( + node: Argument, node_name_to_qconfig: dict[str, QConfigAny] +) -> bool: + """Check if a node has a qconfig of None, i.e. 
user requested to not quantize + the node + """ + return ( + isinstance(node, Node) + and node.name in node_name_to_qconfig + and node_name_to_qconfig[node.name] is None + ) + + +def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None: + """Extract the subgraph that produces the weight for dynamic quant + or weight only quant node and run the subgraph to observe the weight. + Note that the observers of dynamic quant or weight only quant ops are + run during the convert step. + """ + for node in observed.graph.nodes: + if node.op != "call_function": + continue + for node_arg in node.args: + # node_arg is weight + if node_arg and node_arg_is_weight(node, node_arg): + weight_observer_nodes = collect_producer_nodes(node_arg) + if weight_observer_nodes is None: + continue + weight_observer_module = graph_module_from_producer_nodes( + observed, weight_observer_nodes + ) + # run the weight observer + weight_observer_module() + + +def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None: + """If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node, + we'll recursively remove the dequantize Node + """ + if isinstance(arg, Node) and arg.op == "call_method" and arg.target == "dequantize": + quantize_node = arg.args[0] + # we only replace the specific use since dequantize could be used by other nodes + # as well + node.replace_input_with(arg, quantize_node) + elif isinstance(arg, (list, tuple)): + for arg_element in arg: + _maybe_recursive_remove_dequantize(arg_element, node, graph) + elif isinstance(arg, dict): + for arg_element in arg.values(): + _maybe_recursive_remove_dequantize(arg_element, node, graph) + else: + warnings.warn( + f"Unsupported node type in recursive remove dequantize: {type(arg)}" + ) + + +def _get_module_path_and_prefix( + obs_node: Node, + node_name_to_scope: dict[str, tuple[str, type]], + node_name_to_qconfig: dict[str, QConfigAny], +) -> tuple[str, str]: + """Given and observer node, get the `Scope` or the fully qualified name for + the submodule containing the observed node, also return a prefix of "_input" + when the observed node is an input of a F.linear op, and not the output of another + quantized op. + TODO: this logic is hacky, we should think about how to remove it or make it more + general + """ + observed_node = obs_node.args[0] + # an observer can be inserted for both input of the next operator or output of the previous + # operator (they can be the same) + # this flag identifies if the observer is inserted only because the observed node is + # the input of the next operator + assert isinstance( + observed_node, Node + ), f"Expecting observed node to be a Node, but got {observed_node}" + is_input_observer_only = ( + node_name_to_qconfig[observed_node.name] is None + if observed_node.name in node_name_to_qconfig + else None + ) + if is_input_observer_only: + # if the quantize function is at the input of op, then we find the first user of the observer_node + # to get the path. If a linear call_function is in the user list, we return the first instance + # of linear node to get the FQN. 
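+        # Note: module_path and prefix only affect the attribute name under which the
+        # scale/zero_point values are registered on the model (module_path + prefix + "_scale_",
+        # etc.); they do not change numerics.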
+ users = list(obs_node.users) + first_linear_use_or_first_use = users[0] if users else None + linear_node = None + for n in users: + if n.op == "call_function" and n.target == torch.nn.functional.linear: + linear_node = n + break + if linear_node: + first_linear_use_or_first_use = linear_node + prefix = "_input" + else: + # if the quantize function is at the output of the op, we use the observer input node to get the path + first_linear_use_or_first_use = observed_node + prefix = "" + + if ( + first_linear_use_or_first_use + and first_linear_use_or_first_use.name in node_name_to_scope + ): + module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name] + else: + # TODO: it's not used, so actually we can skip quantization + # but this requires changing return type of quantize_node + # we can fix it later if needed + module_path = "" + return module_path, prefix + + +def _insert_dequantize_node(node: Node, graph: Graph) -> None: + """Inserts dequantize node for `node` in `graph`""" + with graph.inserting_after(node): + dequantize_node = graph.call_method("dequantize", (node,)) + for user_node in dict(node.users): + if user_node is not dequantize_node: + user_node.replace_input_with(node, dequantize_node) + + +def _maybe_get_observer_for_node( + node: Node, modules: dict[str, torch.nn.Module] +) -> Optional[torch.nn.Module]: + """ + If the node is observed, return the observer + instance. Otherwise, return None. + """ + for maybe_obs_node in node.users.keys(): + if maybe_obs_node.op == "call_module": + maybe_obs = modules[str(maybe_obs_node.target)] + if _is_activation_post_process(maybe_obs): + return maybe_obs + return None + + +def convert_standalone_module( + node: Node, + modules: dict[str, torch.nn.Module], + model: torch.fx.GraphModule, + is_reference: bool, + backend_config: Optional[BackendConfig], +) -> None: + """Converts a observed standalone module to a quantized standalone module by calling + the fx convert api, currently using the same `is_reference` flag as parent, but we may + changing this behavior in the future (e.g. 
separating quantization and lowering for + standalone module as well) + + Args: + - node: The call_module node of the observed standalone module + - modules: named_module of original model + - model: original model + - is_reference: a flag from parent provided by user to decide if we want to + produce a reference model or a fbgemm/qnnpack model + - backend_config: backend configuration of the target backend of quantization + """ + # TODO: remove is_reference flag + if is_reference: + convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx + else: + convert_fn = torch.ao.quantization.quantize_fx.convert_fx # type: ignore[attr-defined] + # We know that observed standalone module is a GraphModule since + # it's produced by us + observed_standalone_module: GraphModule = modules[str(node.target)] # type: ignore[assignment] + sm_input_quantized_idxs = observed_standalone_module.meta[ + "_observed_graph_module_attrs" + ].standalone_module_input_quantized_idxs + # remove the dequantize nodes for inputs + args = list(node.args) + for idx in range(len(args)): + if idx in sm_input_quantized_idxs: + arg = args[idx] + if arg.op == "call_method" and arg.target == "dequantize": # type: ignore[union-attr] + quantize_node = arg.args[0] # type: ignore[union-attr] + node.replace_input_with(arg, quantize_node) + if len(arg.users) == 0: # type: ignore[union-attr] + model.graph.erase_node(arg) + # add dequantize node for output + sm_output_quantized_idxs = observed_standalone_module.meta[ + "_observed_graph_module_attrs" + ].standalone_module_output_quantized_idxs + if len(sm_output_quantized_idxs) > 0: + assert sm_output_quantized_idxs[0] == 0, "Currently only quantized" + "output idxs = [0] is supported" + + # if it's non-empty, then it means the output is kept in quantized form + # we'll just add a dequantize node after this node + _insert_dequantize_node(node, model.graph) + + # TODO: allow convert_custom_config to override backend_config + # for standalone module + quantized_standalone_module = convert_fn( + observed_standalone_module, backend_config=backend_config + ) + parent_name, name = _parent_name(node.target) + # update the modules dict + setattr(modules[parent_name], name, quantized_standalone_module) + modules[str(node.target)] = quantized_standalone_module + + +def convert_weighted_module( + node: Node, + modules: dict[str, torch.nn.Module], + observed_node_names: set[str], + node_name_to_qconfig: dict[str, QConfigAny], + backend_config: BackendConfig, + is_decomposed: bool = False, + is_reference: bool = False, +) -> None: + """Convert a weighted module to reference quantized module in the model + If the QConfig of a QAT module is not set, the module will still be converted to + a float module. 
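+    The replacement class is looked up in backend_config's mapping from root module to
+    reference quantized module (e.g. nn.Conv2d -> nn.quantized._reference.Conv2d).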
+ + Args: + - node: The call_module node of the observed standalone module + - modules: named_module of original model + - observed_node_names: names for the set of observed fx node, we can skip + this conversion if the node is not observed + """ + original_module = modules[str(node.target)] + qconfig: QConfigAny = original_module.qconfig # type: ignore[assignment] + weight_post_process = None + qat_module_classes = get_qat_module_classes(backend_config) + + if isinstance(original_module, qat_module_classes): + # Converting qat module to a float module, we need to attach + # weight fake_quant to the module, weight fake_quant is assumed to be run during + # QAT so we don't need to run it again here + weight_post_process = original_module.weight_fake_quant + original_module = original_module.to_float() # type: ignore[operator] + # change qat module to float module + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, original_module) + + is_observed = node.name in observed_node_names + # If a qconfig is not defined for this node, then skip converting to a reference module + if ( + qconfig is None + or _has_none_qconfig(node, node_name_to_qconfig) + or not is_observed + ): + return + + # skip converting to reference quantized module if the qconfig is not supported + pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) + dtype_configs = pattern_to_dtype_configs.get(type(original_module), []) + if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs): + return + + # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized + is_weight_quantized = weight_is_quantized(qconfig) + + # the condition for swapping the module to reference quantized module is: + # weights need to be quantized + if not is_weight_quantized: + return + + fused_module = None + float_module = original_module + # extract the individual float_module and fused module + if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule): + fused_module = float_module + float_module = fused_module[0] # type: ignore[index] + + # TODO: move this to the reference quantized module + # weight_qparams or weight_qparams dict + wq_or_wq_dict = {"is_decomposed": is_decomposed} + if isinstance(float_module, torch.nn.RNNCellBase): + weight_post_process_ih = qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process_hh = qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process_ih(float_module.weight_ih) + weight_post_process_hh(float_module.weight_hh) + weight_qparams_ih = get_qparam_dict(weight_post_process_ih) + weight_qparams_hh = get_qparam_dict(weight_post_process_hh) + wq_or_wq_dict.update( + { + "weight_ih": weight_qparams_ih, + "weight_hh": weight_qparams_hh, + } + ) + elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)): + # format for wq_or_wq_dict (flattened attributes): + # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...} + for wn in float_module._flat_weights_names: + if hasattr(float_module, wn) and wn.startswith("weight"): + weight = getattr(float_module, wn) + weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] + if weight_post_process.dtype == torch.qint8: # type: ignore[union-attr] + weight_post_process(weight) # type: ignore[operator, misc] + wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process) + else: + # weight_post_process is None means the original module is not a QAT module + # we need to get weight_post_process from qconfig in this case + is_ptq = 
weight_post_process is None + if is_ptq: + weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] + device = assert_and_get_unique_device(float_module) + if device: + weight_post_process.to(device) + + # Call weight observer/fake_quant at least once to ensure the scales and zero points + # have the right shapes. Note: there are two cases where we don't have to do this: + # + # (1) QAT: The model's forward method already calls the weight observer/fake_quant, + # and this typically happens during training, so we don't need to do it here. + # + # (2) Non-reference (lowered) case: The quantized module's from_float method already + # calls the weight observer/fake_quant, so we don't have to do it here. + # + # Currently we ignore both cases and call the weight observer/fake_quant here + # regardless, which is technically incorrect. For (1), this is mainly to preserve BC + # in test code, which may not always train before convert. In the future, we should + # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941. + # + # For PT2, however, we don't need to preserve BC here, so we can skip this hack + # for QAT. We identify this case as (is_decomposed + is_reference + is_qat). + # Note that we still need it for PTQ in the PT2 flow since the model's forward + # method doesn't call the weight observer. + is_qat = not is_ptq + if not (is_decomposed and is_reference and is_qat): + weight_post_process(float_module.weight) # type: ignore[operator] + + wq_or_wq_dict.update(get_qparam_dict(weight_post_process)) + + # We use the same reference module for all modes of quantization: static, dynamic, weight_only + # root_module_to_quantized_reference_module: module mapping from root (floating point) module class + # to quantized reference module class, e.g. 
nn.Conv2d to nn.quantized._reference.Conv2d + root_module_to_quantized_reference_module = ( + get_root_module_to_quantized_reference_module(backend_config) + ) + ref_qmodule_cls = root_module_to_quantized_reference_module.get( + type_before_parametrizations(float_module), None + ) + assert ( + ref_qmodule_cls is not None + ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" + ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined] + if fused_module is not None: + fused_module[0] = ref_qmodule # type: ignore[operator] + else: + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, ref_qmodule) + + +def _remove_previous_dequantize_in_custom_module( + node: Node, prev_node: Node, graph: Graph +) -> None: + """ + Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows: + + Before: quantize - dequantize - custom_module + After: quantize - custom_module + \\ - dequantize + """ + # expecting the input node for a custom module node to be a Node + assert isinstance( + prev_node, Node + ), f"Expecting the argument for custom module node to be a Node, but got {prev_node}" + if prev_node.op == "call_method" and prev_node.target == "dequantize": + node.replace_input_with(prev_node, prev_node.args[0]) + # Remove the dequantize node if it doesn't have other users + if len(prev_node.users) == 0: + graph.erase_node(prev_node) + + +def convert_custom_module( + node: Node, + graph: Graph, + modules: dict[str, torch.nn.Module], + custom_module_class_mapping: dict[QuantType, dict[type, type]], + statically_quantized_custom_module_nodes: set[Node], +) -> None: + """Converts an observed custom module to a quantized custom module based on + `custom_module_class_mapping` + For static quantization, we'll also remove the previous `dequantize` node and + attach the observer node for output to the module, the observer for the node + will be converted to a dequantize node instead of quantize-dequantize pairs + later in the graph. In the end we would have a quantized custom module that + has the same interface as a default quantized module in nn.quantized namespace, + i.e. quantized input and quantized output. + + Args: + - node: The call_module node of the observed standalone module + - graph: The graph containing the node + - modules: named_module of original model + - custom_module_class_mapping: mapping from observed custom module class to + quantized custom module class, used to swap custom modules + - statically_quantized_custom_module_nodes: we'll add the custom module node + if we find it is statically quantized, this will be used later when converting + observers to quant/dequant node pairs, if the observed node is a statically + quantized custom module nodes, we'll convert the observer to a dequantize node, + this is to keep the interface the same as the default quantized module. + TODO: maybe we want to redesign this part to align with reference model design + as well, but there has been some discussions around the interface, so we can do + it later. 
+ """ + observed_custom_module = modules[str(node.target)] + qconfig = observed_custom_module.qconfig + if activation_is_statically_quantized(qconfig): + statically_quantized_custom_module_nodes.add(node) + if _is_custom_module_lstm(node, modules): + # The inputs are tuples in the form (input, (hidden0, hidden1)) + # Ensure all three input nodes are quantized + assert ( + len(node.args) == 2 + and isinstance(node.args[1], tuple) + and len(node.args[1]) == 2 + ) + (inputs, (hidden0, hidden1)) = node.args # type: ignore[misc] + assert isinstance(inputs, Node) + assert isinstance(hidden0, Node) + assert isinstance(hidden1, Node) + _remove_previous_dequantize_in_custom_module(node, inputs, graph) + _remove_previous_dequantize_in_custom_module(node, hidden0, graph) + _remove_previous_dequantize_in_custom_module(node, hidden1, graph) + elif _is_custom_module_mha(node, modules): + # Inputs are in the form (query, key, value) + # TODO: This is the first step in enabling the full fx custom module + # quantization path for MultiheadAttention, and only covers the inputs + # to the module. + # Additional handling is yet to be implemented for the outputs, similar + # to LSTM custom module + assert len(node.args) == 3 + query, key, value = node.args + assert isinstance(query, Node) + assert isinstance(key, Node) + assert isinstance(value, Node) + _remove_previous_dequantize_in_custom_module(node, query, graph) + _remove_previous_dequantize_in_custom_module(node, key, graph) + _remove_previous_dequantize_in_custom_module(node, value, graph) + else: + # remove the previous dequant node to ensure the inputs are quantized + arg = node.args[0] + assert isinstance(arg, Node) + _remove_previous_dequantize_in_custom_module(node, arg, graph) + # absorb the following observer into the module conversion + activation_post_process = _maybe_get_observer_for_node(node, modules) + assert activation_post_process is not None + observed_custom_module.activation_post_process = activation_post_process + + # swap the observed custom module to quantized custom module + quantized_custom_module_class = get_swapped_custom_module_class( + observed_custom_module, custom_module_class_mapping, qconfig + ) + quantized_custom_module = quantized_custom_module_class.from_observed( + observed_custom_module + ) + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, quantized_custom_module) + + +def convert( + model: GraphModule, + is_reference: bool = False, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + is_standalone_module: bool = False, + _remove_qconfig_flag: bool = True, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, + is_decomposed: bool = False, + keep_original_weights: bool = False, +) -> GraphModule: + """ + We will convert an observed model (a module with observer calls) to a reference + quantized model, the rule is simple: + 1. for each observer module call in the graph, we'll convert it to calls to + quantize and dequantize functions based on the observer instance + 2. 
for weighted operations like linear/conv, we need to convert them to reference + quantized module, this requires us to know whether the dtype configured for the + weight is supported in the backend, this is done in prepare step and the result + is stored in observed_node_names, we can decide whether we need to swap the + module based on this set + + Args: + * `is_standalone_module`: when this flag is True, it means we are quantizing + a submodule that is not inlined in parent module, and will be quantized + separately as one unit. + + * `is_decomposed`: a boolean flag to indicate whether we want to use the + quantize operator for decomposed quantized tensor + (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone + quantized tensor (torch.quantize_per_tensor) + + Returns: + a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config, with + input_quantized_idxs, output_quantized_idxs, please + see docs for :func:`~torch.ao.quantization.prepare_fx` for details + """ + if convert_custom_config is None: + convert_custom_config = ConvertCustomConfig() + + if isinstance(convert_custom_config, dict): + warnings.warn( + "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " + "in a future version. Please pass in a ConvertCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) + + if isinstance(qconfig_mapping, dict): + warnings.warn( + "Passing a QConfig dictionary to convert is deprecated and will not be supported " + "in a future version. Please pass in a QConfigMapping instead.", + FutureWarning, + stacklevel=2, + ) + qconfig_mapping = ( + QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None + ) + qconfig_mapping = copy.deepcopy(qconfig_mapping) + assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping) + + if isinstance(backend_config, dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.", + FutureWarning, + stacklevel=2, + ) + backend_config = BackendConfig.from_dict(backend_config) + + if backend_config is None: + backend_config = get_native_backend_config() + + assert _is_observed_module(model), "incoming model must be produced by prepare_fx" + observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] + node_name_to_scope: dict[str, tuple[str, type]] = ( + observed_graph_module_attrs.node_name_to_scope + ) + prepare_custom_config: PrepareCustomConfig = ( + observed_graph_module_attrs.prepare_custom_config + ) + observed_node_names: set[str] = observed_graph_module_attrs.observed_node_names + node_name_to_qconfig: dict[str, QConfigAny] = ( + observed_graph_module_attrs.node_name_to_qconfig + ) # type: ignore[assignment] + + # mapping from fully qualified module name to module instance + # for example, + # { + # '': Model(...), + # 'linear': Linear(...), + # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), + # } + # We use remove_duplicate=False here because torch.cat uses + # the same activation_post_process module instance but different names + modules = dict(model.named_modules(remove_duplicate=False)) + + # TODO refactor this code once we update the prepare logic to have additional information on + # which graph nodes have been observed and share that with convert to decide which observers to ignore. 
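+    # A qconfig_mapping passed to convert may only disable quantization for nodes (by
+    # mapping them to None); otherwise it must match the mapping used during prepare,
+    # which is validated below.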
+ if qconfig_mapping: + prepare_qconfig_mapping: QConfigMapping = ( + observed_graph_module_attrs.qconfig_mapping + ) # type: ignore[assignment] + modules_copy = copy.deepcopy(modules) + + if observed_graph_module_attrs.is_qat: + _update_qconfig_for_qat(qconfig_mapping, backend_config) + _update_qconfig_for_fusion(model, qconfig_mapping) + + _compare_prepare_convert_qconfig_mappings( + prepare_qconfig_mapping, qconfig_mapping + ) # type: ignore[arg-type] + convert_node_name_to_qconfig = _generate_node_name_to_qconfig( + model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope + ) + # check the convert_node_name_to_qconfig generated and ensure that + # all the values either match what was set in prepare node_name_to_qconfig + # or are set to None in the convert_node_name_to_qconfig. + for k, v in node_name_to_qconfig.items(): + assert ( + k in convert_node_name_to_qconfig + ), f"Expected key {k} in convert node_name_to_qconfig" + if convert_node_name_to_qconfig[k] is not None: + assert qconfig_equals(v, convert_node_name_to_qconfig[k]), ( + f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " + f"but {v} was updated to {convert_node_name_to_qconfig[k]}" + ) + node_name_to_qconfig = convert_node_name_to_qconfig + + custom_module_classes = get_custom_module_class_keys( + convert_custom_config.observed_to_quantized_mapping + ) + custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping + + if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None: + # If we want to do equalization then do the following: + # Calculate the equalization scale, update the observers with the scaled + # inputs, and scale the weight + weight_eq_obs_dict = update_obs_for_equalization(model, modules) + convert_eq_obs(model, modules, weight_eq_obs_dict) + + # always run weight observers in the top level forward method + # for dynamic quant ops or weight only quant ops + _run_weight_observers(model, backend_config) + + # additional state to override inputs to be quantized, if specified + # by the user + placeholder_node_seen_cnt = 0 + input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes + + root_module_to_quantized_reference_module = ( + get_root_module_to_quantized_reference_module(backend_config) + ) + # convert tuples so that it can work with isinstance(module, tuple_of_classes) + root_module_classes = tuple(root_module_to_quantized_reference_module.keys()) + qat_module_classes = get_qat_module_classes(backend_config) + fused_module_classes = get_fused_module_classes(backend_config) + statically_quantized_custom_module_nodes: set[Node] = set() + + for node in list(model.graph.nodes): + if node.op == "placeholder": + cur_placeholder_node_idx = placeholder_node_seen_cnt + placeholder_node_seen_cnt += 1 + if cur_placeholder_node_idx in input_quantized_idxs: + # Inputs are assumed to be quantized if the user specified the + # input_quantized_idxs override. + # we need to dequantize the inputs since all operators took + # floating point inputs in reference quantized models + _insert_dequantize_node(node, model.graph) + elif node.op == "output": + # If the argument is empty we don't need to do anything + if len(output_quantized_idxs) == 0: + continue + # Result are kept quantized if the user specified the + # output_quantized_idxs override. 
+ # Remove the dequantize operator for the node in the end if any + return_node = node + output = node.args[0] + # outputs can be Node, list, tuple, dict, other cases are not supported yet + if isinstance(output, (list, tuple)): + for idx in output_quantized_idxs: + _maybe_recursive_remove_dequantize( + output[idx], return_node, model.graph + ) + elif isinstance(output, (Node, dict)): + # we treat dict as a single argument currently, but it can be extended + # to support {"key": dtype} after we change output_quantized_idxs to + # dict + if 0 in output_quantized_idxs: + _maybe_recursive_remove_dequantize(output, return_node, model.graph) + else: + warnings.warn( + f"Unsupported node type for output_quantized_idxs: {type(output)}" + ) + elif node.op == "call_module": + mod = _get_module(node, modules) + assert mod is not None + if _is_activation_post_process(mod): + observed_node = node.args[0] + if observed_node in statically_quantized_custom_module_nodes: + _replace_observer_or_dequant_stub_with_dequantize_node( + node, model.graph + ) + else: + if is_decomposed: + _replace_observer_with_quantize_dequantize_node_decomposed( + model, + node, + modules, + node_name_to_scope, + node_name_to_qconfig, + ) + else: + _replace_observer_with_quantize_dequantize_node( + model, + node, + modules, + node_name_to_scope, + node_name_to_qconfig, + ) + elif isinstance(mod, DeQuantStub): + _replace_observer_or_dequant_stub_with_dequantize_node( + node, model.graph + ) + elif _is_observed_standalone_module(mod): + convert_standalone_module( + node, modules, model, is_reference, backend_config + ) + # below this point `type_before_parametrizations` is used + # instead of `type` to handle situations with fx quant + sparsity + elif type_before_parametrizations(mod) in set(root_module_classes).union( + qat_module_classes + ).union(fused_module_classes): + # extra check for fused module classes to make sure they are fused module classes + # of target modules + if ( + type_before_parametrizations(mod) in fused_module_classes + and type_before_parametrizations(mod[0]) not in root_module_classes + ): # type: ignore[index] + continue + convert_weighted_module( + node, + modules, + observed_node_names, + node_name_to_qconfig, + backend_config, + is_decomposed, + is_reference, + ) + elif type_before_parametrizations(mod) in custom_module_classes: + convert_custom_module( + node, + model.graph, + modules, + custom_module_class_mapping, + statically_quantized_custom_module_nodes, + ) + + # remove deadcode after converting observers to quant/dequant ops + model.graph.eliminate_dead_code() + model = GraphModule(model, model.graph) + + # TODO: maybe move this to quantize_fx.py + if not is_reference: + model = lower_to_fbgemm( + model, node_name_to_qconfig, node_name_to_scope, keep_original_weights + ) + + # TODO: this looks hacky, we want to check why we need this and see if we can + # remove this + # removes qconfig and activation_post_process modules + if _remove_qconfig_flag: + _remove_qconfig(model) + model.delete_all_unused_submodules() + model.meta.pop("_observed_graph_module_attrs", None) + return model + + +def _convert_fx( + graph_module: GraphModule, + is_reference: bool, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + is_standalone_module: bool = False, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, + is_decomposed: bool = False, + keep_original_weights: 
bool = False, +) -> GraphModule: + """`is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`""" + if convert_custom_config is None: + convert_custom_config = ConvertCustomConfig() + + if isinstance(convert_custom_config, dict): + warnings.warn( + "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " + "in a future version. Please pass in a ConvertCustomConfig instead.", + FutureWarning, + stacklevel=3, + ) + convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) + + _check_is_graph_module(graph_module) + preserved_attr_names = convert_custom_config.preserved_attributes + preserved_attrs = { + attr: getattr(graph_module, attr) + for attr in preserved_attr_names + if hasattr(graph_module, attr) + } + + quantized = convert( + graph_module, + is_reference, + convert_custom_config, + is_standalone_module, + _remove_qconfig_flag=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + is_decomposed=is_decomposed, + keep_original_weights=keep_original_weights, + ) + + attach_preserved_attrs_to_model(quantized, preserved_attrs) + return quantized + + +def _convert_to_reference_decomposed_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""Convert a calibrated or trained model to a reference quantized model, with + decomposed representation for quantized Tensor + see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details, + reference quantized model is a standard representation of a quantized model provided + by FX Graph Mode Quantization, it can be further lowered to run on the target + hardware, like accelerators + + Note: this is not public API + + Args: + * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend. See + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + Return: + A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. 
backend_config = get_default_backend_config("fbgemm") + reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model) + + """ + torch._C._log_api_usage_once( + "quantization_api.quantize_fx._convert_to_reference_decomposed_fx" + ) + return _convert_fx( + graph_module, + is_reference=True, + convert_custom_config=convert_custom_config, + _remove_qconfig=False, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + is_decomposed=True, + ) diff --git a/torchao/quantization/pt2e_flow/pt2e/duplicate_dq_pass.py b/torchao/quantization/pt2e_flow/pt2e/duplicate_dq_pass.py new file mode 100644 index 0000000000..1d62651b34 --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/duplicate_dq_pass.py @@ -0,0 +1,82 @@ +# mypy: allow-untyped-defs +import logging +import operator + +import torch +from torch.fx.node import map_arg +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +from torchao.quantization.pt2e_flow.pt2e.utils import ( + _filter_sym_size_users, + _is_valid_annotation, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +__all__ = ["DuplicateDQPass"] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +def _maybe_duplicate_dq( + gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node +): + annotation = user.meta.get("quantization_annotation", None) + if not _is_valid_annotation(annotation): + return + with gm.graph.inserting_after(dq_node): + new_node = gm.graph.node_copy(dq_node) + + def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: + if n == dq_node: + return new_node + else: + return n + + new_args = map_arg(user.args, maybe_replace_node) + new_kwargs = map_arg(user.kwargs, maybe_replace_node) + user.args = new_args # type: ignore[assignment] + user.kwargs = new_kwargs # type: ignore[assignment] + + +class DuplicateDQPass(PassBase): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in _DEQUANTIZE_OPS: + dq_users = _filter_sym_size_users(node) + if len(dq_users) <= 1: + continue + # Do not duplicate dq for dynamic quantization + # Pattern: choose_qparam - getitem - q - dq + q_node = node.args[0] + if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS: + getitem_node = q_node.args[1] + if ( + isinstance(getitem_node, torch.fx.node.Node) + and getitem_node.op == "call_function" + and getitem_node.target == operator.getitem + ): + choose_qparam_node = getitem_node.args[0] + if ( + isinstance(choose_qparam_node, torch.fx.node.Node) + and choose_qparam_node.op == "call_function" + and choose_qparam_node.target + == torch.ops.quantized_decomposed.choose_qparams.tensor + ): + continue + for user in dq_users: + _maybe_duplicate_dq(graph_module, node, user) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/torchao/quantization/pt2e_flow/pt2e/export_utils.py b/torchao/quantization/pt2e_flow/pt2e/export_utils.py new file mode 100644 index 0000000000..8ab28c8f75 --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/export_utils.py @@ -0,0 
+1,240 @@ +# mypy: allow-untyped-defs +import types + +import torch +import torch.nn.functional as F + +from torchao.quantization.pt2e_flow.utils import _assert_and_get_unique_device + +__all__ = [ + "model_is_exported", +] + +_EXPORTED_TRAINING_ATTR = "_exported_training" + + +class _WrapperModule(torch.nn.Module): + """Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you + are trying to export a callable. + """ + + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, *args, **kwargs): + """Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`.""" + return self.fn(*args, **kwargs) + + +def model_is_exported(m: torch.nn.Module) -> bool: + """ + Return True if the `torch.nn.Module` was exported, False otherwise + (e.g. if the model was FX symbolically traced or not traced at all). + """ + return isinstance(m, torch.fx.GraphModule) and any( + "val" in n.meta for n in m.graph.nodes + ) + + +def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool): + """ + Switch dropout patterns in the model between train and eval modes. + + Dropout has different behavior in train vs eval mode. For exported models, + however, calling `model.train()` or `model.eval()` does not automatically switch + the dropout behavior between the two modes, so here we need to rewrite the aten + dropout patterns manually to achieve the same effect. + + See https://github.com/pytorch/pytorch/issues/103681. + """ + # Avoid circular dependencies + from .utils import _get_aten_graph_module_for_pattern + + # Needed to ensure subgraph matches are self-contained + m.graph.eliminate_dead_code() + m.recompile() + + for inplace in [False, True]: + + def dropout_train(x): + return F.dropout(x, p=0.5, training=True, inplace=inplace) + + def dropout_eval(x): + return F.dropout(x, p=0.5, training=False, inplace=inplace) + + example_inputs = (torch.randn(1),) + if train_to_eval: + match_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_train), + example_inputs, + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_eval), + example_inputs, + ) + else: + match_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_eval), + example_inputs, + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + _WrapperModule(dropout_train), + example_inputs, + ) + + from torch.fx.subgraph_rewriter import replace_pattern_with_filters + + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + match_filters=[], + ignore_literals=True, + ) + m.recompile() + + +def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool): + """ + Switch batchnorm patterns in the model between train and eval modes. + + Batchnorm has different behavior in train vs eval mode. For exported models, + however, calling `model.train()` or `model.eval()` does not automatically switch + the batchnorm behavior between the two modes, so here we need to rewrite the aten + batchnorm patterns manually to achieve the same effect. + """ + # TODO(Leslie): This function still fails to support custom momentum and eps value. + # Enable this support in future updates. 
+ + # Avoid circular dependencies + from .utils import _get_aten_graph_module_for_pattern + + # Needed to ensure subgraph matches are self-contained + m.graph.eliminate_dead_code() + m.recompile() + + def bn_train( + x: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ): + return F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True + ) + + def bn_eval( + x: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ): + return F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False + ) + + example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + device = _assert_and_get_unique_device(m) + is_cuda = device is not None and device.type == "cuda" + bn_train_aten = _get_aten_graph_module_for_pattern( + _WrapperModule(bn_train), + example_inputs, + is_cuda, + ) + bn_eval_aten = _get_aten_graph_module_for_pattern( + _WrapperModule(bn_eval), + example_inputs, + is_cuda, + ) + + if train_to_eval: + match_pattern = bn_train_aten + replacement_pattern = bn_eval_aten + else: + match_pattern = bn_eval_aten + replacement_pattern = bn_train_aten + + from torch.fx.subgraph_rewriter import replace_pattern_with_filters + + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + match_filters=[], + ignore_literals=True, + ) + m.recompile() + + +# TODO: expose these under this namespace? +def _move_exported_model_to_eval(model: torch.fx.GraphModule): + """ + Move an exported GraphModule to eval mode. + + This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm. + QAT users should call this before performing inference on the model. + + This call is idempotent; if the model is already in eval mode, nothing will happen. + """ + is_training = getattr(model, _EXPORTED_TRAINING_ATTR, True) + if not is_training: + return model + setattr(model, _EXPORTED_TRAINING_ATTR, False) + _replace_dropout(model, train_to_eval=True) + _replace_batchnorm(model, train_to_eval=True) + return model + + +def _move_exported_model_to_train(model: torch.fx.GraphModule): + """ + Move an exported GraphModule to train mode. + + This is equivalent to model.train() but only for certain special ops like dropout, batchnorm. + QAT users should call this before performing training on the model. + + This call is idempotent; if the model is already in train mode, nothing will happen. + """ + is_training = getattr(model, _EXPORTED_TRAINING_ATTR, False) + if is_training: + return model + setattr(model, _EXPORTED_TRAINING_ATTR, True) + _replace_dropout(model, train_to_eval=False) + _replace_batchnorm(model, train_to_eval=False) + return model + + +def _allow_exported_model_train_eval(model: torch.fx.GraphModule): + """ + Allow users to call `model.train()` and `model.eval()` on an exported model, + but with the effect of changing behavior between the two modes limited to special + ops only, which are currently dropout and batchnorm. + + Note: This does not achieve the same effect as what `model.train()` and `model.eval()` + does in eager models, but only provides an approximation. In particular, user code + branching on `training` flag will not function correctly in general because the branch + is already specialized at export time. 
Additionally, other ops beyond dropout and batchnorm + that have different train/eval behavior will also not be converted properly. + """ + + def _train(self, mode: bool = True): + if mode: + _move_exported_model_to_train(self) + else: + _move_exported_model_to_eval(self) + + def _eval(self): + _move_exported_model_to_eval(self) + + model.train = types.MethodType(_train, model) # type: ignore[method-assign] + model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] + return model diff --git a/torchao/quantization/pt2e_flow/pt2e/graph_utils.py b/torchao/quantization/pt2e_flow/pt2e/graph_utils.py new file mode 100644 index 0000000000..7a4b2309e3 --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/graph_utils.py @@ -0,0 +1,180 @@ +# mypy: allow-untyped-defs +import itertools +import operator +from collections import OrderedDict +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union + +import torch +from torch.export import ExportedProgram +from torch.fx import Node +from torch.fx.passes.utils.source_matcher_utils import ( + SourcePartition, + check_subgraphs_connected, + get_source_partitions, +) + +__all__ = [ + "find_sequential_partitions", + "get_equivalent_types", + "update_equivalent_types_dict", + "bfs_trace_with_node_process", +] + +_EQUIVALENT_TYPES: list[set] = [ + {torch.nn.Conv1d, torch.nn.functional.conv1d}, + {torch.nn.Conv2d, torch.nn.functional.conv2d}, + {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d}, + {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_}, + {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm}, + {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_}, + {torch.add, operator.add, operator.iadd, "add", "add_"}, + {torch.mul, operator.mul, operator.imul, "mul", "mul_"}, +] + + +def _create_equivalent_types_dict(): + _DICT = {} + for values in _EQUIVALENT_TYPES: + for v in values: + _DICT[v] = list(values) + return _DICT + + +_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict() + + +def get_equivalent_types() -> list[set]: + return _EQUIVALENT_TYPES + + +def update_equivalent_types_dict(customized_equivalent_types=None): + """Help function for user who wants to customize the _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT. + When customized_equivalent_types passes in, + re-generate _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT. 
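+
+    A possible usage sketch (the particular groups below are only an illustration, not
+    a recommended configuration)::
+
+        import operator
+        import torch
+        from torchao.quantization.pt2e_flow.pt2e.graph_utils import (
+            update_equivalent_types_dict,
+        )
+
+        # treat only these groups as interchangeable when matching partitions
+        update_equivalent_types_dict([
+            {torch.nn.Conv2d, torch.nn.functional.conv2d},
+            {torch.nn.ReLU, torch.nn.functional.relu},
+            {torch.add, operator.add},
+        ])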
+ """ + if customized_equivalent_types is None: + raise ValueError("customized_equivalent_types should not be None") + global _EQUIVALENT_TYPES + global _EQUIVALENT_TYPES_DICT + _EQUIVALENT_TYPES = customized_equivalent_types + _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict() + + +def _partitions_sequential(partitions: Sequence[SourcePartition]): + prev_partition = None + for partition in partitions: + if prev_partition is not None and not check_subgraphs_connected( + prev_partition, partition + ): + return False + prev_partition = partition + return True + + +def _get_matching_types(partition_type): + matching_types = [partition_type] + if partition_type in _EQUIVALENT_TYPES_DICT: + matching_types.extend(_EQUIVALENT_TYPES_DICT[partition_type]) + return matching_types + + +def _valid_type_sequence(partition_types: list[Any]): + partition_types_set = set() # type: ignore[var-annotated] + for partition_type in partition_types: + matching_types = _get_matching_types(partition_type) + matching_types_set = set(matching_types) + if len(partition_types_set & matching_types_set) > 0: + return False + partition_types_set |= matching_types_set + return True + + +def find_sequential_partitions( + gm: torch.fx.GraphModule, + partition_types: list[Any], + include_functional_equivalent=True, + filter_fn: Optional[Callable[[Node], bool]] = None, +): + if not _valid_type_sequence(partition_types): + raise ValueError( + f"Invalid partition types: {partition_types}. Each type in the sequence must be unique" + ) + + typed_partitions: OrderedDict[Any, list[SourcePartition]] = OrderedDict() + for partition_type in partition_types: + types_to_match = _get_matching_types(partition_type) + partitions = get_source_partitions(gm.graph, types_to_match, filter_fn) + typed_partitions[partition_type] = list( + itertools.chain.from_iterable(partitions.values()) + ) + + typed_partitions_list = list(typed_partitions.values()) + fusion_candidates = itertools.product(*typed_partitions_list) + fused_partitions = [ + candidate + for candidate in fusion_candidates + if _partitions_sequential(candidate) + ] + return fused_partitions + + +def _get_submodule( + graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int +) -> tuple[str, torch.nn.Module, torch.fx.Node]: + submod_node = node.args[arg_index] + assert isinstance(submod_node, torch.fx.Node) + assert submod_node.op == "get_attr" + assert isinstance(submod_node.target, str) + submodule = graph_module.get_submodule(submod_node.target) + # pyre-ignore + return submod_node.target, submodule, node + + +def _get_control_flow_submodules( + graph_module: torch.fx.GraphModule, +) -> list[tuple[str, torch.nn.Module, torch.fx.Node]]: + """ + Returns a list of submodules used for control flow operations + (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look + into submodules). Specifically, the returned value is a list containing a + tuple of (name of the submodule that's stored in the graph module, the + submodule itself, and the fx node that uses this submodule). 
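+
+    For example (a schematic sketch, not runnable output; the submodule names follow the
+    usual torch.export convention and may differ in practice), a top-level graph with a
+    single ``torch.cond`` call would yield something like::
+
+        [
+            ("true_graph_0", <GraphModule for the true branch>, cond_node),
+            ("false_graph_0", <GraphModule for the false branch>, cond_node),
+        ]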
+ """ + control_flow_submodules = [] + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + + if node.target is torch.ops.higher_order.cond: + control_flow_submodules.append(_get_submodule(graph_module, node, 1)) + control_flow_submodules.append(_get_submodule(graph_module, node, 2)) + if node.target is torch.ops.higher_order.map_impl: + control_flow_submodules.append(_get_submodule(graph_module, node, 0)) + + return control_flow_submodules + + +def bfs_trace_with_node_process( + model: Union[ExportedProgram, torch.fx.GraphModule], node_op: Callable +) -> None: + """Traverse the graph module and apply node_op to each node.""" + + assert isinstance( + model, (ExportedProgram, torch.fx.GraphModule) + ), f"Expected GraphModule or ExportedProgram, got {type(model)}" + gm = model.graph_module if isinstance(model, ExportedProgram) else model + queue = [gm] + while queue: + current_graph_module = queue.pop(0) + for node in current_graph_module.graph.nodes: + if node.op in ["output", "placeholder"]: + continue + + node_op(node) + + control_flow_submodules = [ + submodule + for _, submodule, _ in _get_control_flow_submodules(current_graph_module) + ] + queue.extend(control_flow_submodules) diff --git a/torchao/quantization/pt2e_flow/pt2e/port_metadata_pass.py b/torchao/quantization/pt2e_flow/pt2e/port_metadata_pass.py new file mode 100644 index 0000000000..68000b8261 --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/port_metadata_pass.py @@ -0,0 +1,215 @@ +# mypy: allow-untyped-defs +import logging +from typing import Optional + +import torch +from torch._export.error import InternalError +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +from torchao.quantization.pt2e_flow.pt2e.utils import ( + _filter_sym_size_users, + _find_q_dq_node_for_user, + _is_valid_annotation, +) +from torchao.quantization.pt2e_flow.quantizer import QuantizationSpecBase + +logger = logging.getLogger(__name__) +logger.setLevel(logging.ERROR) + +__all__ = ["PortNodeMetaForQDQ"] + +_METADATA_TO_PORT = [ + "stack_trace", + "quantization_tag", +] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + +_CHOOSE_QPARAMS_OPS = [ + torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, +] + + +def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None: + from_meta = from_node.meta + for meta_name in _METADATA_TO_PORT: + if meta_name in from_meta: + to_node.meta[meta_name] = from_meta[meta_name] + + +def _has_quant_annotation(node: torch.fx.Node) -> bool: + return "quantization_annotation" in node.meta + + +def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]: + # BFS to look for choose qparams + from collections import deque + + queue = deque(list(node.users.keys())) + while len(queue): + n = queue.popleft() + if n.op == "output": + continue + if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS: + return n + for k in n.users.keys(): + queue.append(k) + return None + + +def _port_metadata_for_input_quant_nodes( + input_node: torch.fx.Node, + node: torch.fx.Node, + qspec: 
Optional[QuantizationSpecBase], +): + if qspec is None: + return + + is_dynamic_quant = getattr(qspec, "is_dynamic", None) + if is_dynamic_quant is not None and is_dynamic_quant is True: + choose_qparams_node = _find_choose_qparams_node(input_node) + if choose_qparams_node is None: + raise ValueError(f"No chose qparams node found for {node}") + choose_qparam_users = _filter_sym_size_users(choose_qparams_node) + if len(choose_qparam_users) != 2: + raise InternalError(f"Expecting exactly two user for {choose_qparams_node}") + scale_node = choose_qparam_users.pop() + dynamic_q_node = next(iter(scale_node.users.keys())) + dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node) + if len(dynamic_q_node_users) > 1: + raise InternalError(f"Expecting single user for {dynamic_q_node}") + dynamic_dq_node = dynamic_q_node_users.pop() + _add_metadata(choose_qparams_node, node) + _add_metadata(dynamic_q_node, node) + _add_metadata(dynamic_dq_node, node) + else: + q_node, dq_node = _find_q_dq_node_for_user(input_node, node) + if q_node is None or dq_node is None: + return + # add metadata for all the node between q_node and get_attr node + # if the q_node can be traced back to get_attr node + q_to_get_attr_nodes = [q_node] + q_node_input = q_node.args[0] + while ( + isinstance(q_node_input, torch.fx.Node) + and q_node_input.op == "call_function" + and q_node_input.target + in [ + torch.ops.aten.flatten.using_ints, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.transpose.Dimname, + torch.ops.aten.transpose.int, + torch.ops.aten.transpose_, + torch.ops.aten.view_copy.default, + torch.ops.aten.view.default, + torch.ops.aten._mkldnn_transpose, + ] + ): + q_to_get_attr_nodes.append(q_node_input) + q_node_input = q_node_input.args[0] + if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr": + for n in q_to_get_attr_nodes: + _add_metadata(n, q_node_input) + _add_metadata(dq_node, node) + + +def _port_metadata_for_output_quant_nodes( + node: torch.fx.Node, qspec: Optional[QuantizationSpecBase] +): + if qspec is None: + return + + node_users = _filter_sym_size_users(node) + if len(node.users) == 0: + return + if len(node_users) != 1: + logger.warning(f"Expecting {node} to have single user") # noqa: G004 + q_node = node_users.pop() + if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS: + logger.warning( + f"Expecting {node} user to be a quantized op but got {q_node}" # noqa: G004 + ) # noqa: G004 + return + + _add_metadata(q_node, node) + + +class PortNodeMetaForQDQ(PassBase): + """ + Port metadata for nodes added by quantization flow. + For static quant these are: + - quantizer_per_tensor.default, dequantize_per_tensor.default + - quantizer_per_channel.default, dequantize_per_channel.default + For dynamic quant these are: + - choose_qparams.tensor + - quantizer_per_tensor.tensor, dequantize_per_tensor.tensor + - quantizer_per_channel.default, dequantize_per_channel.default + + Rules of porting metadata: + - Metadata to be ported: + - nn_module_stack + - stack_trace + - quantization_tag + - Metadata to NOT be ported: + - Everything else + - Rules: + - Statically quantized patterns: + - Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node. + - Quantize nodes on the outputs inherit metadata of the producer node. 
+        - Example 1:
+          - Original: [Conv -> AvgPool -> Linear]
+          - Quantized: [Q -> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
+          - Inner brackets specify which nodes the Q/DQ nodes inherit metadata from:
+          - [Q -> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ]
+          - Note that the first Q and the last DQ do not inherit metadata from any node
+        - Example 2:
+          - Original: [Conv -> AvgPool -> Linear]
+          - AvgPool is not quantized
+          - Quantized: [Q -> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
+          - Inner brackets specify which nodes the Q/DQ nodes inherit metadata from:
+          - [Q -> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ]
+          - Note that the DQ and Q nodes around AvgPool do not inherit metadata from AvgPool,
+            because AvgPool was not supposed to be quantized. Metadata porting relies on the
+            quantization_annotation on the nodes (in this case the AvgPool node) to conclude
+            whether the node or pattern was supposed to be quantized, and subsequently to decide
+            whether the preceding Q, if any, should inherit metadata from AvgPool.
+      - Dynamically quantized patterns:
+        - Inputs that are dynamically quantized have choose_qparams, quantize and dequantize nodes
+        - For example, below linear is dynamically quantized while the rest is statically quantized:
+          - Original: [Conv -> AvgPool -> Linear]
+          - Quantized: [Q -> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_qparams -> Q -> DQ -> Linear]
+          - Quantized: [Q -> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_qparams -> Q -> DQ -> Linear]]
+          - Note that the first Q does not inherit metadata from any node
+    NB:
+    - The best place to port metadata is during observer conversion to q/dq. This is because that step
+      knows precisely which quantization spec is converted to q/dq and thus where the metadata should be
+      ported from. However, since the FX and PT2E quant workflows share a common code base, doing it there
+      hurts readability quite a bit; doing it via a separate pass keeps the code readable. Once we are able
+      to refactor the PT2E quant code, this pass should be integrated into the refactored variant of the
+      "convert" step.
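+
+    A minimal usage sketch (assuming `graph_module` already contains the q/dq nodes
+    inserted by the pt2e convert step; typically this pass is run for you inside
+    `convert_pt2e`)::
+
+        from torchao.quantization.pt2e_flow.pt2e.port_metadata_pass import (
+            PortNodeMetaForQDQ,
+        )
+
+        result = PortNodeMetaForQDQ().call(graph_module)
+        graph_module = result.graph_module
+        # q/dq nodes now carry "stack_trace" / "quantization_tag" copied from the
+        # nodes they were inserted around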
+ """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + annotation = node.meta.get("quantization_annotation", None) + if _is_valid_annotation(annotation): + input_qspec_map = node.meta["quantization_annotation"].input_qspec_map + output_qspec = node.meta["quantization_annotation"].output_qspec + for input_node, qspec in input_qspec_map.items(): + _port_metadata_for_input_quant_nodes(input_node, node, qspec) + _port_metadata_for_output_quant_nodes(node, output_qspec) + return PassResult(graph_module, True) diff --git a/torchao/quantization/pt2e_flow/pt2e/prepare.py b/torchao/quantization/pt2e_flow/pt2e/prepare.py new file mode 100644 index 0000000000..5f8f8cd3bb --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/prepare.py @@ -0,0 +1,662 @@ +# mypy: allow-untyped-defs +import copy +from dataclasses import asdict +from typing import Any, Optional, Union + +import torch +from torch._subclasses import FakeTensor +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.fx.custom_config import PrepareCustomConfig +from torch.ao.quantization.fx.prepare import ( + _insert_obs_or_fq, + _save_state, +) +from torch.ao.quantization.qconfig import QConfigAny +from torch.fx import Graph, GraphModule, Node +from torch.fx.node import Argument + +from torchao.quantization.pt2e_flow import ( + CUSTOM_KEY, + NUMERIC_DEBUG_HANDLE_KEY, + ObserverOrFakeQuantize, + _DerivedObserverOrFakeQuantize, +) +from torchao.quantization.pt2e_flow.fake_quantize import ( + FixedQParamsFakeQuantize, +) +from torchao.quantization.pt2e_flow.observer import ( + FixedQParamsObserver, + _is_activation_post_process, + _PartialWrapper, +) +from torchao.quantization.pt2e_flow.quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + FixedQParamsQuantizationSpec, + QuantizationSpec, + QuantizationSpecBase, + SharedQuantizationSpec, +) + +# TODO: make pt2e folder private? 
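+
+# Editor's note: an illustrative, self-contained sketch (not part of this module's API)
+# of the union-find idea used further below to resolve `SharedQuantizationSpec` groups;
+# plain strings stand in for edges/nodes here, see `_find_root_edge_or_node` and
+# `_get_edge_or_node_to_group_id` for the real implementation.
+def _sharing_group_sketch() -> int:
+    # each key points at its parent; a self-loop marks the root of a sharing group
+    shared_with_map = {"a": "a", "b": "a", "c": "b"}
+
+    def find_root(x: str) -> str:
+        while shared_with_map[x] != x:
+            x = shared_with_map[x]
+        return x
+
+    # "a", "b" and "c" all resolve to the same root, i.e. a single observer group
+    return len({find_root(k) for k in shared_with_map})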
+__all__ = [ + "prepare", +] + + +def _is_activation_post_process_node( + node: Node, named_modules: dict[str, torch.nn.Module] +) -> bool: + return ( + isinstance(node, torch.fx.Node) + and node.op == "call_module" + and _is_activation_post_process(named_modules[str(node.target)]) + ) + + +def _get_observer_kwargs( + quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec], +): + kwargs_dict = asdict(quant_spec) + return copy.deepcopy(kwargs_dict) + + +def _create_obs_or_fq_from_qspec( + quantization_spec: Optional[QuantizationSpecBase], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +): + """Create observer or fake quantize objects based on quantization spec + + Args: + quantization_spec: used to store parameters to create the observer or fake quantizer + obs_or_fq_map: this is a map from edge/output to the corresponding observer/fake_quant + instance, it may be reused for different edge/output depending on configuration + """ + if quantization_spec is None: + return None + if isinstance(quantization_spec, SharedQuantizationSpec): + edge_or_node = quantization_spec.edge_or_node + assert edge_or_node in obs_or_fq_map, ( + "please make sure only refer to edge or node that has " + f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}" + ) + return obs_or_fq_map[edge_or_node] + elif isinstance(quantization_spec, DerivedQuantizationSpec): + # can't use asdict, so not calling get_observer_kwargs here + kwargs = { + "dtype": quantization_spec.dtype, + "derive_qparams_fn": quantization_spec.derive_qparams_fn, + "quant_min": quantization_spec.quant_min, + "quant_max": quantization_spec.quant_max, + "qscheme": quantization_spec.qscheme, + "ch_axis": quantization_spec.ch_axis, + } + edge_or_nodes = quantization_spec.derived_from + obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes] + kwargs["obs_or_fqs"] = obs_or_fqs + return _DerivedObserverOrFakeQuantize.with_args(**kwargs)() + elif isinstance(quantization_spec, FixedQParamsQuantizationSpec): + kwargs = _get_observer_kwargs(quantization_spec) + observer_ctr = FixedQParamsObserver.with_args(**kwargs) + if is_qat: + return FixedQParamsFakeQuantize.with_args(observer=observer_ctr)() + else: + return observer_ctr() + + assert isinstance( + quantization_spec, QuantizationSpec + ), f"Expected QuantizationSpec got: {quantization_spec}" + observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr + kwargs = _get_observer_kwargs(quantization_spec) + kwargs.pop("observer_or_fake_quant_ctr") + # we will remove is_dynamic from QuantizationSpec because + # it seems that dynamic range quantization + obs_or_fq_class = observer_or_fake_quant_ctr + if isinstance(observer_or_fake_quant_ctr, _PartialWrapper): + obs_or_fq_class = observer_or_fake_quant_ctr.p.func # type: ignore[union-attr, assignment] + if "PerChannel" not in obs_or_fq_class.__name__: # type: ignore[operator, union-attr] + kwargs.pop("ch_axis") + return observer_or_fake_quant_ctr.with_args(**kwargs)() + + +def _find_root_edge_or_node( + edge_or_node: EdgeOrNode, shared_with_map: dict[EdgeOrNode, EdgeOrNode] +) -> EdgeOrNode: + """Find the root node for the sharing tree + Args: + edge_or_node: edge/node that we want to find the root + shared_with_map: each edge/node points to the parent, the root node will points to itself + + Returns: + root edge/node + """ + parent = shared_with_map[edge_or_node] + if parent == edge_or_node: + return edge_or_node + root = _find_root_edge_or_node(parent, shared_with_map) + # path 
compression + shared_with_map[edge_or_node] = root + return root + + +def _union( + parent: EdgeOrNode, + child: EdgeOrNode, + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +) -> None: + """Merge the subtree for `child` with `parent`, the order is important here""" + root_parent = _find_root_edge_or_node(parent, shared_with_map) + root_child = _find_root_edge_or_node(child, shared_with_map) + # union the two trees by pointing the root of child to root of parent + shared_with_map[root_child] = root_parent + + +def _update_shared_with( + child: EdgeOrNode, + qspec: QuantizationSpecBase, + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +): + """Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec` + configuration and established the relationship between `edge_or_node` with the edge/node that it + is pointing to, we'll use this information in the end to get the group id + """ + if isinstance(qspec, SharedQuantizationSpec): + parent = qspec.edge_or_node + # we point from edge_or_node to the node that it is sharing_with, e.g. + # qspec for a = SharedQuantizationSpec(b) means `a` points to `b` + _union(parent, child, shared_with_map) + + +def _unwrap_shared_qspec( + qspec: QuantizationSpecBase, + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], + shared_with_map: dict[EdgeOrNode, EdgeOrNode], +) -> QuantizationSpecBase: + """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec) + if qspec is SharedQuantizationSpec + (1). tries to find the root edge or node for the node that the qspec points to + (2). recursively find the root qspec based on the qspec for the root node + """ + if isinstance(qspec, SharedQuantizationSpec): + sharing_with = qspec.edge_or_node + root = _find_root_edge_or_node(sharing_with, shared_with_map) + qspec = edge_or_node_to_qspec[root] + return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) + return qspec + + +def _has_same_attr( + qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase, attr_name: str +): + return ( + hasattr(qspec_a, attr_name) + and hasattr(qspec_b, attr_name) + and getattr(qspec_a, attr_name) == getattr(qspec_b, attr_name) + ) or (not hasattr(qspec_a, attr_name) and not hasattr(qspec_b, attr_name)) + + +def _get_edge_or_node_to_qspec( + model: torch.fx.GraphModule, +) -> dict[EdgeOrNode, QuantizationSpecBase]: + """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes""" + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] = {} + for n in model.graph.nodes: + if hasattr(n, "meta") and "quantization_annotation" in n.meta: + qa = n.meta["quantization_annotation"] + for input_to_n, qspec in qa.input_qspec_map.items(): + input_edge = (input_to_n, n) + edge_or_node_to_qspec[input_edge] = qspec + if qa.output_qspec is not None: + output_node = n + qspec = qa.output_qspec + edge_or_node_to_qspec[output_node] = qspec + return edge_or_node_to_qspec + + +def _union_input_edge_with( + input_edge, + input_edge_root_qspec, + edge_or_node, + edge_or_node_to_qspec, + shared_with_map, +): + """Union input edge with another edge or node, used in implicit sharing to point the current input + edge to other user edges of the producer node, or the output of producer node since these are + referring to the same Tensor + """ + root_qspec = None + if edge_or_node in edge_or_node_to_qspec: + qspec = edge_or_node_to_qspec[edge_or_node] + root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) + # TODO: add assertions for 
types of root qspecs + if root_qspec is not None and all( + _has_same_attr(root_qspec, input_edge_root_qspec, attr) + for attr in [ + "dtype", + "is_dynamic", + "quant_min", + "quant_max", + "qscheme", + "ch_axis", + "scale", + "zero_point", + ] + ): + # the input arg to the node should reuse the existing output observer for arg + # since dtype is the same (we may want to extend this to be a more strict check + # in the future) + # so we point from `input_edge` to `arg` (output of the argument) + _union(edge_or_node, input_edge, shared_with_map) + + +def _get_edge_or_node_to_group_id( + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], +) -> dict[EdgeOrNode, int]: + """Map from edge/node to the group ID, generated from quantization annotations, + edge/node with the same group ID should use the same observer/fake_quant instance + + This is applying SharedQuantizationSpec configuration and map each edge/node to a group + There is another implicit sharing that's built in the quantization, when we have the following: + * op1 -> op2 + * output of op1: int8_qspec + * (op1 -> op2) input edge: int8_qspec + we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor. + + Figuring out the correct group ID for all edge/node is a standard union find problem: + https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/ + + Args: + edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations + Returns: + edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edge or node that + belongs to the same group should have the same id + + Example: + op2 -> cat1 -> cat2 + op1 / / + op3 + edge_or_node_to_qspec: { + op1: int8_qspec, + op2: int8_qspec, + (op1, cat1): int8_qspc, + (op2, cat1): SharedQuantizationSpec((op1, cat1)), + cat1: SharedQuantizationSpec((op1, cat1)), + (op3, cat2): int8_qspec, + (cat1, cat2): SharedQuantizationSpec((op3, cat2)), + cat2: SharedQuantizationSpec((op3, cat2)), + } + + edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) + edge_or_node_to_group_id: { + op1: 1, + op2: 1, + (op1, cat1): 1, + (op2, cat1): 1, + cat1: 1, + (op3, cat2): 1, + (cat1, cat2): 1, + cat2: 1, + } + # everything are in the same group because (cat1) and (cat1, cat2) are implicitly shared, which + # connects the two sharing group around cat1 and cat2 op due to transitive sharing + """ + # means the observer of key should be shared with observer with value, by default it will + # be shared with itself + shared_with_map: dict[EdgeOrNode, EdgeOrNode] = { + k: k for k in edge_or_node_to_qspec.keys() + } + for edge_or_node, qspec in edge_or_node_to_qspec.items(): + if isinstance(edge_or_node, torch.fx.Node): + output_node = edge_or_node + _update_shared_with(output_node, qspec, shared_with_map) + else: + input_edge = edge_or_node + input_edge_root_qspec = _unwrap_shared_qspec( + qspec, edge_or_node_to_qspec, shared_with_map + ) + + assert isinstance(input_edge, tuple) + arg, n = input_edge + if n.meta["quantization_annotation"].allow_implicit_sharing: + # NOTE: the order is important here, we first share with other users and then share with previous + # output because the reverse order could cause circular dependency + # e.g node1 -> node2 + # \ -> node3 + # when processing (node1, node2), if we first point (node1, node2) to node1 + # Step 1. shared_map = {(node1, node2): node1} + # Step 2. 
after that, we point the (node1, node2) to its other user (node1, node3) , + # which means shared_map = {(node1, node2): node1, node1: (node1, node3)} + # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3) + # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll + # have a circular dependency + # the following order works around this issue, but this does not allow arbitrary configuration + # of sharing so it might break in a different case in the future, when it breaks + # quantizer writer can check the notes here to debug the issue + + # sharing with other users of the producer node + # (arg, user) + if not isinstance(arg, Node) or not isinstance(n, Node): + raise Exception( # noqa: TRY002 + f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}" + ) + for user in arg.users: + if user is n: + continue + arg_to_user_edge = (arg, user) + _union_input_edge_with( + input_edge, + input_edge_root_qspec, + arg_to_user_edge, + edge_or_node_to_qspec, + shared_with_map, + ) + + # sharing with output of producer node + _union_input_edge_with( + input_edge, + input_edge_root_qspec, + arg, + edge_or_node_to_qspec, + shared_with_map, + ) + + _update_shared_with(input_edge, qspec, shared_with_map) + + # now that we get the sharing relations between all edges and nodes, we can assingn group ids + cur_group_id = 0 + edge_or_node_to_group_id: dict[EdgeOrNode, int] = {} + for edge_or_node in shared_with_map.keys(): + root = _find_root_edge_or_node(edge_or_node, shared_with_map) + if root not in edge_or_node_to_group_id: + edge_or_node_to_group_id[root] = cur_group_id + cur_group_id += 1 + edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root] + + return edge_or_node_to_group_id + + +def _get_obs_or_fq_map( + edge_or_node_to_group_id: dict[EdgeOrNode, int], + edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase], + is_qat: bool, +) -> dict[EdgeOrNode, ObserverOrFakeQuantize]: + """Generates the EdgeOrNode to observer/fake_quant instances + Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant + instances + """ + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {} + group_id_to_obs_or_fq: dict[int, ObserverOrFakeQuantize] = {} + for edge_or_node, qspec in edge_or_node_to_qspec.items(): + group_id = edge_or_node_to_group_id[edge_or_node] + if group_id not in group_id_to_obs_or_fq: + # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify + # the implementation for _create_obs_or_fq_from_qspec + group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec( + qspec, obs_or_fq_map, is_qat + ) + obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id] + return obs_or_fq_map + + +def _maybe_insert_input_observer_for_arg_or_kwarg( + node: Union[Node, Any], + arg: Argument, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Argument: + """ + Given a `node` and an `arg`, inserts an input observer between + `node` and `arg` if necessary. 
+ """ + # for ops such as torch.cat([x0, x1]), + # traverse through the list + if isinstance(arg, (list, tuple)): + new_arg_to_return = [] + for inner_arg in arg: + new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + inner_arg, + qconfig, + model, + named_modules, + obs_or_fq_map, + is_qat, + ) + new_arg_to_return.append(new_inner_arg) + return type(arg)(new_arg_to_return) + + if not isinstance(arg, Node): + return arg + assert isinstance(arg, Node) + # default (no observer) + new_arg = arg + + # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes + original_arg = arg + while _is_activation_post_process_node(original_arg, named_modules): + original_arg = original_arg.args[0] # type: ignore[assignment] + assert isinstance( + original_arg, Node + ), f"expect original argument to be a Node, but got: {type(original_arg)}" + + input_edge = (original_arg, node) + if input_edge not in obs_or_fq_map: + return new_arg + # input_edge needs to be observed + input_edge_obs_or_fq = obs_or_fq_map[input_edge] + if input_edge_obs_or_fq is None: + return new_arg + + arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None) + # the arg is observed as the output and is using the same instance as the input_edge + # we'll reuse the inserted observer/fake_quant + if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id( + input_edge_obs_or_fq + ): + return new_arg + + # otherwise, we'll insert a new observer/fake_quant node + + # skip inserting new observers if the same observer instance is inserted before for another user + # Example: + # conv1 -> obs1 -> existing_obs -> conv2 + # \ -> conv3 + # + # instead of inserting new observers we will have: + # conv1 -> obs1 -> existing_obs -> conv2 + # \ -> conv3 + for maybe_obs_node in arg.users.keys(): + if not _is_activation_post_process_node(maybe_obs_node, named_modules): + continue + maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] + if id(maybe_obs_mod) == id(input_edge_obs_or_fq): + return maybe_obs_node + + assert isinstance(model.graph, Graph) + new_arg = _insert_obs_or_fq( + arg, input_edge_obs_or_fq, model, named_modules, model.graph + ) + return new_arg + + +def _maybe_insert_input_observers_for_node( + node: Node, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> None: + """ + If needed, inserts observers to the input args and kwargs of `node`. + Note: modifies `node` inplace. + + For example, if cur_node needs an observer after prev_node, we change from + + prev_node -> cur_node + + To + + prev_node -> obs -> cur_node + + """ + # Look through every input arg. If that arg's target dtype does not + # match the current node's target dtype, insert an observer. + new_args = [] + for arg in node.args: + new_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, + arg, + qconfig, + model, + named_modules, + obs_or_fq_map, + is_qat, + ) + new_args.append(new_arg) + + # Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and + # gelu has a has an approximate kwarg that persist in exported graph. + # This is just a work around for these. 
+ assert ( + node.target == torch.ops.aten.clone.default + or node.target == torch.ops.aten.zeros_like.default + or node.target == torch.ops.aten.gelu.default + or len(node.kwargs) == 0 + ), " expecting kwargs for aten op IR to be empty" + + # assign the new args to the node, inplace + node.args = tuple(new_args) + + +def _maybe_insert_output_observer_for_node( + node: Node, + model: torch.nn.Module, + named_modules: dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Optional[Node]: + if node in obs_or_fq_map: + output_act_obs_or_fq = obs_or_fq_map[node] + new_output = _insert_obs_or_fq( + node, output_act_obs_or_fq, model, named_modules, graph + ) + # propagate numeric debug handle from original node to observer/fake_quant node + if ( + isinstance(node, Node) + and isinstance(new_output, Node) + and CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + if CUSTOM_KEY not in new_output.meta: + new_output.meta[CUSTOM_KEY] = {} + new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] + return new_output + return None + + +def _maybe_insert_input_and_output_observers_for_node( + node: Node, + model: torch.fx.GraphModule, + obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +): + this_node_quantization_annotation = ( + node.meta["quantization_annotation"] + if "quantization_annotation" in node.meta + else None + ) + if this_node_quantization_annotation is None: + return + + named_modules = dict(model.named_modules(remove_duplicate=False)) + _maybe_insert_input_observers_for_node( + node, + None, # qconfig + model, + named_modules, + obs_or_fq_map, + is_qat, + ) + + output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor) + if not output_is_a_tensor: + return + + # this returns the new observer node if it was needed + maybe_output_obs_node = _maybe_insert_output_observer_for_node( + node, model, named_modules, model.graph, obs_or_fq_map, is_qat + ) + + if maybe_output_obs_node is None: + return + # Update users of original node to use the output observer + # instead. For example, change + # + # next_node + # / + # cur_node -> obs + # + # to + # + # next_node + # / + # cur_node -> obs + # + # We need to save orig users before updating uses because + # the list of users will change as we update uses + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is maybe_output_obs_node: + continue + user_node.replace_input_with(node, maybe_output_obs_node) + + +def prepare( + model: GraphModule, + node_name_to_scope: dict[str, tuple[str, type]], + is_qat: bool, + obs_or_fq_callback=None, +) -> GraphModule: + # Since we are mutating the graph as we go, we iterate over the original + # nodes before observer insertion, instead of model.graph.nodes. 
+ nodes_before_observation = list(model.graph.nodes) + + # At the high level we construct a map from EdgeOrNode to a observer_or_fake_quant instance + # all edge/nodes that belongs to the same group will use the same instance + # and when we insert observers we'll just query this map to get the correct observer_or_fake_quant + # instance + edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model) + edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) + obs_or_fq_map = _get_obs_or_fq_map( + edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat + ) + if obs_or_fq_callback: + obs_or_fq_callback(model, obs_or_fq_map) + + for node in nodes_before_observation: + # TODO: simplify logic for inserting observers + _maybe_insert_input_and_output_observers_for_node( + node, model, obs_or_fq_map, is_qat + ) + + model = GraphModule(model, model.graph) + + _save_state( + model, + {}, # node_name_to_qconfig + node_name_to_scope, + PrepareCustomConfig(), + {}, # equalization_node_name_to_qconfig + QConfigMapping(), + is_qat, + set(), # observed_node_names + ) + return model diff --git a/torchao/quantization/pt2e_flow/pt2e/qat_utils.py b/torchao/quantization/pt2e_flow/pt2e/qat_utils.py new file mode 100644 index 0000000000..e314ef4ff2 --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/qat_utils.py @@ -0,0 +1,991 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import itertools +import operator +from typing import TYPE_CHECKING, Any, Callable, Optional + +import torch +import torch.nn.functional as F +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.fx import Graph, GraphModule, Node +from torch.fx.subgraph_rewriter import ReplacedPatterns, replace_pattern_with_filters + +from torchao.quantization.pt2e_flow.pt2e.export_utils import _WrapperModule +from torchao.quantization.pt2e_flow.quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + QuantizationSpecBase, + SharedQuantizationSpec, +) + +from .utils import ( + _get_aten_graph_module_for_pattern, + _is_bn_node, + _is_conv_or_conv_transpose_node, + _is_conv_transpose_fn, + fold_bn_weights_into_conv_node, +) + +if TYPE_CHECKING: + from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch + +__all__ = [] # type: ignore[var-annotated] + + +def _get_quantized_conv_bn_example_inputs_kwargs( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + is_cuda: bool, +) -> dict[str, Any]: + """ + Optional example inputs for quantized and folded conv-bn patterns + used in convert, expressed as kwargs. 
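+
+ For example (a rough sketch of the per-channel case with an unquantized bias):
+
+ _get_quantized_conv_bn_example_inputs_kwargs(
+ is_per_channel=True, has_bias=True, bias_is_quantized=False, is_cuda=False
+ )
+ # returns {"weight_scale": ..., "weight_zero_point": ..., "conv_bias": ...}
+
+ In the per-tensor case the scale and zero point are literals and are not included.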
+ """ + kwargs = {} + # Per tensor quantization uses literals to represent scale and zero + # point, so there is no need to include them here as kwargs + if is_per_channel: + kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float) + kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int) + if has_bias and bias_is_quantized: + kwargs["bias_scale"] = torch.tensor([1], dtype=torch.float) + kwargs["bias_zero_point"] = torch.tensor([0], dtype=torch.int) + if has_bias: + kwargs["conv_bias"] = torch.randn(1) + if is_cuda: + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + kwargs[k] = v.cuda() + return kwargs + + +def _get_conv_bn_pattern(conv_fn: Callable) -> Callable: + def _conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + x = conv_fn(x, conv_weight, conv_bias) + x = F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True + ) + return x + + return _WrapperModule(_conv_bn_pattern) + + +# TODO: merge this with the `no_conv_bias` case +def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable: + def _qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + """ + Approximated method to fuse conv and bn. It requires only one forward pass. + conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std. + This is based on `nniqat.ConvBn2d._forward_approximate`. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0 + weight_shape[weight_in_channel_axis] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype) + x = conv_fn(x, scaled_weight, zero_bias) + x = x / scale_factor.reshape(bias_shape) + x = x + conv_bias.reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=True, + eps=bn_eps, + ) + return x + + return _WrapperModule(_qat_conv_bn_pattern) + + +def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable: + def _qat_conv_bn_pattern_no_conv_bias( + x: torch.Tensor, + conv_weight: torch.Tensor, + # Not used, only for matching convenience + conv_bias: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ) -> torch.Tensor: + """ + Same as `_get_qat_conv_bn_pattern`, but handles the case with no conv bias. 
+ """ + # TODO: allow setting eps + bn_eps = 1e-5 + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0 + weight_shape[weight_in_channel_axis] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + x = conv_fn(x, scaled_weight, None) + x = x / scale_factor.reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=True, + eps=bn_eps, + ) + return x + + return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias) + + +def _append_qdq(x, is_per_channel, is_bias, kwargs): + """ + Helper function to append q-dq ops after `x`, using dummy values for the qparams + and qmin/qmax. We use dummy values here because we match with `ignore_literals=True` + and will manually replace these values after subgraph rewriting. + + Return the dq node. + """ + # Dummy args to be passed into q-dq ops + per_channel_axis = 0 + scale_key = "bias_scale" if is_bias else "weight_scale" + zp_key = "bias_zero_point" if is_bias else "weight_zero_point" + scale = kwargs[scale_key] if is_per_channel else 1.0 + zp = kwargs[zp_key] if is_per_channel else 0 + qmin = -127 + qmax = 127 + dtype = torch.int8 + + qd = torch.ops.quantized_decomposed + if is_per_channel: + x = qd.quantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype) + x = qd.dequantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype) + else: + x = qd.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype) + x = qd.dequantize_per_tensor(x, scale, zp, qmin, qmax, dtype) + return x + + +def _get_quantized_qat_conv_bn_pattern( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + conv_fn: Callable, + bn_is_training: bool, +) -> Callable: + """ + Return the quantized version of QAT conv + BN pattern. + This is based on `nniqat.ConvBn2d._forward_approximate`, + used in QAT convert. We first match this pattern and replace + it with the normal [conv - bn] pattern, then fold the BN + weights into conv. 
+ """ + # TODO: allow setting eps + bn_eps = 1e-5 + + def _quantized_qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + scaled_weight = _append_qdq( + scaled_weight, + is_per_channel, + is_bias=False, + kwargs=kwargs, + ) + if has_bias: + zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype) + if bias_is_quantized: + zero_bias = _append_qdq( + zero_bias, + is_per_channel, + is_bias=True, + kwargs=kwargs, + ) + x = conv_fn(x, scaled_weight, zero_bias) + else: + x = conv_fn(x, scaled_weight, None) + x = x / scale_factor.reshape(bias_shape) + if has_bias: + x = x + kwargs["conv_bias"].reshape(bias_shape) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=bn_is_training, + eps=bn_eps, + ) + return x + + return _WrapperModule(_quantized_qat_conv_bn_pattern) + + +def _get_folded_quantized_qat_conv_bn_pattern( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + conv_fn: Callable, + bn_is_training: bool, +) -> Callable: + """ + Quantized QAT conv - bn pattern with bn weights being folded into conv. + """ + # TODO: allow setting eps + bn_eps = 1e-5 + + def _folded_quantized_qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + conv_weight = _append_qdq( + conv_weight, + is_per_channel, + is_bias=False, + kwargs=kwargs, + ) + if has_bias: + bias = kwargs["conv_bias"] + if bias_is_quantized: + bias = _append_qdq( + bias, + is_per_channel, + is_bias=True, + kwargs=kwargs, + ) + else: + bias = None + x = conv_fn(x, conv_weight, bias) + x = F.batch_norm( + x, + bn_running_mean, + bn_running_var, + bn_weight, + bn_bias, + training=bn_is_training, + eps=bn_eps, + ) + return x + + return _WrapperModule(_folded_quantized_qat_conv_bn_pattern) + + +def _has_conv_bias_filter( + match: "InternalMatch", + original_graph: Graph, + pattern_graph: Graph, +) -> bool: + """ + Match filter for the subgraph rewriter that returns True if the conv node in + the original graph has bias. + """ + for n in match.nodes_map.values(): + if _is_conv_or_conv_transpose_node(n): + return len(n.args) > 2 and n.args[2] is not None + raise ValueError("Could not find conv node in matched conv + bn pattern") + + +def _no_conv_bias_filter( + match: "InternalMatch", + original_graph: Graph, + pattern_graph: Graph, +) -> bool: + """ + Match filter for the subgraph rewriter that returns True if the conv node in + the original graph does NOT have bias. 
+ """ + return not _has_conv_bias_filter(match, original_graph, pattern_graph) + + +def _is_quantize(n: Node) -> bool: + return n.target in [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + ] + + +def _is_dequantize(n: Node) -> bool: + return n.target in [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ] + + +def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> dict[str, tuple[Node, Node]]: + """ + Helper function to extract the nodes in the conv-bn fusion pattern after + subgraph rewriting, in the form of a map: + + {name: (original_node, replacement_node)} + + The following names must exist in the map: + + "conv", "conv_weight", "conv_input", "bn", "getitem" + + The following names may exist in the map: + + "conv_weight_q", "conv_weight_dq", "conv_bias", + "conv_bias_q", "conv_bias_dq" + """ + + def _get_nodes(nodes: list[Node]) -> tuple[Node, Node, Optional[Node]]: + """ + Return a 3-tuple of (conv_node, bn_node, getitem_node). + This asserts that the match contains exactly one of each node. + """ + conv_node, bn_node, getitem_node = None, None, None + for n in nodes: + if n.op != "call_function": + continue + if _is_conv_or_conv_transpose_node(n): + assert conv_node is None + conv_node = n + if _is_bn_node(n): + assert bn_node is None + bn_node = n + if n.target == operator.getitem: + assert getitem_node is None + getitem_node = n + assert conv_node is not None + assert bn_node is not None + return (conv_node, bn_node, getitem_node) + + def _get_q_dq_nodes(n: Node) -> tuple[Node, Node, Node]: + """ + Return a 3-tuple of (orig_node, q_node, dq_node). 
+ """ + assert _is_dequantize(n) + q_node = n.args[0] + assert isinstance(q_node, Node) + assert _is_quantize(q_node) + orig_node = q_node.args[0] + assert isinstance(orig_node, Node) + return (orig_node, q_node, n) + + original_nodes = list(_filter_nodes_map(r.nodes_map).values()) + o_conv, o_bn, o_getitem = _get_nodes(original_nodes) + r_conv, r_bn, r_getitem = _get_nodes(r.replacements) + + # Create the mapping from original node to replacement node + assert o_getitem is None + assert r_getitem is None + mapping = { + "conv": (o_conv, r_conv), + "bn": (o_bn, r_bn), + } + + # Extract conv input and weight + # Note: here we extract the original nodes indirectly through the pattern nodes + # because the args of the original nodes are no longer available after replacement + (p_conv, _, _) = _get_nodes(list(r.nodes_map.keys())) + (p_conv_input, p_conv_weight, *_) = p_conv.args + (r_conv_input, r_conv_weight, *_) = r_conv.args + assert isinstance(p_conv_input, Node) + assert isinstance(p_conv_weight, Node) + assert isinstance(r_conv_input, Node) + assert isinstance(r_conv_weight, Node) + o_conv_input = r.nodes_map[p_conv_input] + o_conv_weight = r.nodes_map[p_conv_weight] + + # If conv weight is quantized, extract the q - dq nodes + if _is_dequantize(p_conv_weight): + p_conv_weight, p_conv_weight_q, p_conv_weight_dq = _get_q_dq_nodes( + p_conv_weight + ) + r_conv_weight, r_conv_weight_q, r_conv_weight_dq = _get_q_dq_nodes( + r_conv_weight + ) + o_conv_weight = r.nodes_map[p_conv_weight] + o_conv_weight_q = r.nodes_map[p_conv_weight_q] + o_conv_weight_dq = r.nodes_map[p_conv_weight_dq] + mapping["conv_weight_q"] = (o_conv_weight_q, r_conv_weight_q) + mapping["conv_weight_dq"] = (o_conv_weight_dq, r_conv_weight_dq) + mapping["conv_input"] = (o_conv_input, r_conv_input) + mapping["conv_weight"] = (o_conv_weight, r_conv_weight) + + # Extract conv bias + if len(p_conv.args) > 2 and len(r_conv.args) > 2: + p_conv_bias = p_conv.args[2] + r_conv_bias = r_conv.args[2] + assert isinstance(p_conv_bias, Node) + assert isinstance(r_conv_bias, Node) + o_conv_bias = r.nodes_map[p_conv_bias] + + # If conv bias is quantized, extract the q - dq nodes + if _is_dequantize(p_conv_bias): + p_conv_bias, p_conv_bias_q, p_conv_bias_dq = _get_q_dq_nodes(p_conv_bias) + r_conv_bias, r_conv_bias_q, r_conv_bias_dq = _get_q_dq_nodes(r_conv_bias) + o_conv_bias = r.nodes_map[p_conv_bias] + o_conv_bias_q = r.nodes_map[p_conv_bias_q] + o_conv_bias_dq = r.nodes_map[p_conv_bias_dq] + mapping["conv_bias_q"] = (o_conv_bias_q, r_conv_bias_q) + mapping["conv_bias_dq"] = (o_conv_bias_dq, r_conv_bias_dq) + mapping["conv_bias"] = (o_conv_bias, r_conv_bias) + return mapping + + +def _filter_nodes_map(nodes_map: dict[Node, Node]) -> dict[Node, Node]: + """ + Return a filtered `nodes_map` returned from the subgraph rewriter. + The filtered `nodes_map` will contain only nodes that are actually + matched in the pattern, excluding None or placeholder nodes. 
+ """ + new_nodes_map: dict[Node, Node] = {} + for pattern_node, graph_node in nodes_map.items(): + # bias can be None + if graph_node is None: + continue + # skip pattern placeholder nodes + if pattern_node.op == "placeholder": + continue + new_nodes_map[pattern_node] = graph_node + return new_nodes_map + + +# TODO: this is error prone, use the replace_literals_with_placeholders hack instead +def _copy_over_literal_conv_args(original_node: Node, new_node: Node): + """ + Copy over literal args in conv, such as stride and padding, from the matched node + in the original graph to its replacement in the new graph. + + This is needed due to the following limitation in the subgraph rewriter when used + with dynamo export: literal (non-tensor) args are not supported in the match and + replacement patterns. This is because dynamo export automatically inlines these + literal args, making them dead placeholder nodes. In the future, we should check + if dynamo export can optionally disable this inlining, or if subgraph rewriter + can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419. + + Note: Unlike other tensor args like conv weights and biases, literal args are + preserved in the original nodes after replacement, so we can access them here. + """ + assert _is_conv_or_conv_transpose_node(original_node) + assert _is_conv_or_conv_transpose_node(new_node) + # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups] + new_args = list(new_node.args) + if len(new_args) < 3: + # bias is optional, when it is not present, it means it is None + new_args.append(None) + new_node.args = tuple(new_args[:3]) + original_node.args[3:] + + +def _update_conv_input_qspec_map_after_replacement( + original_node: Node, replacement_node: Node +): + """ + Update the `input_qspec_map` in the annotation after subgraph rewriting. + + The original annotation referred to the nodes in the original graph, + so the keys in the `input_qspec_map` will need to be updated to reflect + the corresponding nodes in the replacement graph. + """ + assert _is_conv_or_conv_transpose_node(original_node) + assert _is_conv_or_conv_transpose_node(replacement_node) + if "quantization_annotation" not in original_node.meta: + return + original_input_qspec_map = original_node.meta[ + "quantization_annotation" + ].input_qspec_map + input_qspec_map = {} + # get the list of configs, it should be ordered as input, weight, bias + # note: this is really hacky, we need a better solution, hopefully + # in subgraph_rewriter, issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820 + all_configs = list(original_input_qspec_map.items()) + # input activation + input_qspec_map[replacement_node.args[0]] = all_configs[0][1] + # weight + input_qspec_map[replacement_node.args[1]] = all_configs[1][1] + # bias + if len(replacement_node.args) > 2 and len(all_configs) > 2: + input_qspec_map[replacement_node.args[2]] = all_configs[2][1] + replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map + + +def _update_special_qspecs_after_replacement( + node: Node, + original_to_replacement_node: dict[Node, Node], +): + """ + Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s + used in `node`'s quantization annotation after subgraph rewriting. 
+ + The original annotation referred to the nodes in the original graph, + so the nodes used in these special quantization specs will need to + be updated to the corresponding nodes in the replacement graph. + """ + + def _get_new_edge_or_node(edge_or_node: EdgeOrNode): + if isinstance(edge_or_node, Node): + _node = edge_or_node + return original_to_replacement_node.get(_node, _node) + elif ( + isinstance(edge_or_node, tuple) + and len(edge_or_node) == 2 + and all(isinstance(x, Node) for x in edge_or_node) + ): + src, dest = edge_or_node + return ( + original_to_replacement_node.get(src, src), + original_to_replacement_node.get(dest, dest), + ) + else: + raise ValueError("unexpected type for edge_or_node: ", type(edge_or_node)) + + def _get_new_qspec(qspec: QuantizationSpecBase): + if isinstance(qspec, SharedQuantizationSpec): + new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node) + return SharedQuantizationSpec(new_edge_or_node) + elif isinstance(qspec, DerivedQuantizationSpec): + new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from] + return dataclasses.replace(qspec, derived_from=new_derived_from) + else: + return qspec + + if "quantization_annotation" not in node.meta: + return + annotation = node.meta["quantization_annotation"] + for input_node, qspec in annotation.input_qspec_map.items(): + annotation.input_qspec_map[input_node] = _get_new_qspec(qspec) + annotation.output_qspec = _get_new_qspec(annotation.output_qspec) + + +def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule: + # Example inputs for conv-bn1d patterns + _conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for conv-bn2d patterns + _conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return m + is_cuda_options = [True, False] if torch.cuda.is_available() else [False] + for is_cuda in is_cuda_options: + m = _fuse_conv_bn_qat_helper( + m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv_transpose1d, _conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fuse_conv_bn_qat_helper( + m, F.conv_transpose2d, _conv2d_bn_example_inputs, is_cuda=is_cuda + ) + return m + + +def _fuse_conv_bn_qat_helper( + m: GraphModule, + conv_fn: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool, +) -> GraphModule: + """ + Given a graph of decomposed aten ops, replace the (conv + bn) pattern with + the fused QAT subgraph equivalent. The input graph should already be annotated. + The annotations in the original nodes will be preserved in the corresponding + nodes in the new subgraph. + + Note: This also handles the (conv + bn + relu) pattern. 
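+
+ The replacement is done in two passes, once for convs with bias and once
+ for convs without, since the two replacement patterns differ.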
+ """ + m.graph.eliminate_dead_code() + m.recompile() + + conv_bn_pattern = _get_conv_bn_pattern(conv_fn) + match_pattern = _get_aten_graph_module_for_pattern( + conv_bn_pattern, + example_inputs, + is_cuda, + ) + + # Step (1): Replace patterns with conv bias + # + # Here we do replacement separately for cases with and without conv bias, since + # the replacement patterns for these two cases are substantially different. + # TODO: use the public replace_pattern API once it also returns replacement nodes + + qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn) + replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern( + qat_conv_bn_pattern, + example_inputs, + is_cuda, + ) + replacements_with_conv_bias = replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern_with_conv_bias, + match_filters=[_has_conv_bias_filter], + ignore_literals=True, + ) + m.recompile() + + # Step (2): Replace patterns without conv bias + + qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn) + replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern( + qat_conv_bn_pattern_no_conv_bias, + example_inputs, + is_cuda, + ) + replacements_no_conv_bias = replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern_no_conv_bias, + match_filters=[_no_conv_bias_filter], + ignore_literals=True, + ) + m.recompile() + + # Step (3): Post processing + # + # Due to limited functionality in the subgraph rewriter, here we manually + # update the replacement graph as follows: + # + # (a) Copy over metadata from original subgraph. This ensures the stack traces + # and annotations are preserved in the new subgraph + # + # (b) Copy over literal args for conv from the original subgraph + # TODO: do this for literal args for batchnorm as well + # + # (c) Update all references of the old nodes in the original subgraph to refer + # to the corresponding nodes in the new subgraph in the annotations + # + # In the future, we should try to push as much of this functionality into the + # subgraph rewriter as possible, so we don't have to manually copy anything over. + # For more detail, see https://github.com/pytorch/pytorch/issues/100419. + + all_original_to_replacement_nodes = {} + for r in replacements_with_conv_bias + replacements_no_conv_bias: + replacement_dict = _get_conv_bn_pattern_nodes(r) + # The original conv node's "nn_module_stack" + conv_nn_module = replacement_dict["conv"][0].meta.get("nn_module_stack", None) + for k, node_tuple in replacement_dict.items(): + original_node, replacement_node = node_tuple + # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem] + replacement_node.meta = original_node.meta + # If original_node is a get_attr node, it doesn't have nn_module_stack. + # In this case, we copy nn_module_stack from the original conv node. 
+ if ( + k in ["conv_input", "conv_weight"] + and conv_nn_module + and "nn_module_stack" not in replacement_node.meta + ): + replacement_node.meta["nn_module_stack"] = copy.deepcopy(conv_nn_module) + if _is_conv_or_conv_transpose_node(original_node): + # Step (3b): Copy over conv literal args + _copy_over_literal_conv_args(original_node, replacement_node) + # Step (3c): Update old references in the conv node's input_qspec_map + _update_conv_input_qspec_map_after_replacement( + original_node, replacement_node + ) + all_original_to_replacement_nodes[original_node] = replacement_node + + # Step (3c): Update old references in the special qspecs for all nodes in the graph + for n in m.graph.nodes: + _update_special_qspecs_after_replacement(n, all_original_to_replacement_nodes) + + return m + + +def _duplicate_dequantize_node(m: GraphModule): + """ + Helper function to duplicate all dequantize nodes in the graph if the + node has more than one user. For example: + + Before: + quantize -> dequantize -> a + \\--> b + \\--> c + + After: + quantize -> dequantize_1 -> a + \\--> dequantize_2 -> b + \\--> dequantize_3 -> c + + This is useful for subgraph rewriting. E.g. if we wish to match the + pattern [dequantize - a] above, subgraph matching would fail because + the dequantize node has users outside the matched portion of the graph. + Instead, we match [dequantize_1 - a], which is safe. + """ + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor + for n in m.graph.nodes: + if n.op != "call_function" or n.target != dq_op or len(n.users) == 1: + continue + for user in list(n.users): + with m.graph.inserting_before(n): + new_node = m.graph.create_node("call_function", dq_op, n.args, n.kwargs) + user.replace_input_with(n, new_node) + m.graph.erase_node(n) + m.recompile() + + +def _remove_extra_dequantize(m: GraphModule): + """ + Removes duplicate dequant nodes in the graph, for an operator that has + multiple dequant nodes as a user, replace them with a single dequant node + that can be shared across all the uses. This should be seen as the "reverse" + of `_duplicate_dequantize_node`. + """ + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor + for n in m.graph.nodes: + dq_users = [ + user + for user in n.users + if user.op == "call_function" and user.target == dq_op + ] + if len(dq_users) > 1: + with m.graph.inserting_after(dq_users[0]): + new_node = m.graph.create_node( + "call_function", dq_op, dq_users[0].args, {} + ) + for dq_user in dq_users: + dq_user.replace_all_uses_with(new_node) + m.graph.erase_node(dq_user) + m.recompile() + + +def _copy_over_q_dq_args(original_node: Node, replacement_node: Node): + """ + Given a pair of quantize or dequantize nodes, copy over all literal args + from the original node to the replacement node. 
+ """ + # For quantize_per_tensor, scale and zp are literals and need to be copied + # For quantize_per_channel, scale and zp are get_attr nodes and should be skipped + assert original_node.target == replacement_node.target + if original_node.target in ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ): + # Args: input, [scale, zp, qmin, qmax, dtype] + start_copy_arg_index = 1 + elif original_node.target in ( + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ): + # Args: input, scale, zp, [axis, qmin, qmax, dtype] + start_copy_arg_index = 3 + else: + raise ValueError( + f"Expected quantize/dequantize nodes, got '{original_node.target}'" + ) + replacement_node.args = ( + replacement_node.args[:start_copy_arg_index] + + original_node.args[start_copy_arg_index:] + ) + + +def _fold_conv_bn_qat(m: GraphModule) -> GraphModule: + # Example inputs for quantized and folded conv-bn1d patterns used in convert + _quantized_conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for quantized and folded conv-bn2d patterns used in convert + _quantized_conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return m + is_cuda_options = [True, False] if torch.cuda.is_available() else [False] + for is_cuda in is_cuda_options: + m = _fold_conv_bn_qat_helper( + m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv_transpose1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda + ) + m = _fold_conv_bn_qat_helper( + m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda + ) + + # remove in place add from batchnorm tracking traning stats + for node in m.graph.nodes: + if ( + node.target == torch.ops.aten.add_.Tensor + and node.args[0].op == "get_attr" + and node.args[1] == 1 + and torch.nn.modules.batchnorm.BatchNorm2d + in [val[1] for val in node.meta["source_fn_stack"]] + ): + m.graph.erase_node(node) + + m.graph.eliminate_dead_code() + m.recompile() + + return m + + +def _fold_conv_bn_qat_helper( + m: GraphModule, + conv_fn: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool, +) -> GraphModule: + """ + Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv. 
+ """ + + m.graph.eliminate_dead_code() + m.recompile() + _duplicate_dequantize_node(m) + + # Step (1): Replace QAT pattern with simple [conv - bn] pattern + replacements = [] + replacement_options = itertools.product( + [True, False], # is_per_channel + [True, False], # has_bias + [True, False], # bias_is_quantized + [True, False], # bn_is_training + ) + for ( + is_per_channel, + has_bias, + bias_is_quantized, + bn_is_training, + ) in replacement_options: + # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily + # filter out one of the values for this flag to avoid having duplicate patterns + if not has_bias and bias_is_quantized: + continue + kwargs = _get_quantized_conv_bn_example_inputs_kwargs( + is_per_channel, has_bias, bias_is_quantized, is_cuda + ) + match_pattern = _get_quantized_qat_conv_bn_pattern( + is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training + ) + match_pattern = _get_aten_graph_module_for_pattern( + match_pattern, + example_inputs, + is_cuda, + **kwargs, + ) + replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern( + is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training + ) + replacement_pattern = _get_aten_graph_module_for_pattern( + replacement_pattern, + example_inputs, + is_cuda, + **kwargs, + ) + replacements.extend( + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + ignore_literals=True, + ) + ) + m.recompile() + _remove_extra_dequantize(m) + + for r in replacements: + node_map = _get_conv_bn_pattern_nodes(r) + + # Step (2): Copy over metadata from original subgraph + for original_node, replacement_node in node_map.values(): + replacement_node.meta = original_node.meta + + # Step (3): Copy over args for weight (and optionally bias) q - dq nodes + _copy_over_q_dq_args(*node_map["conv_weight_q"]) + _copy_over_q_dq_args(*node_map["conv_weight_dq"]) + if "conv_bias_q" in node_map: + assert "conv_bias_dq" in node_map + _copy_over_q_dq_args(*node_map["conv_bias_q"]) + _copy_over_q_dq_args(*node_map["conv_bias_dq"]) + + # Step (4): Fold BN weights into conv + conv_bias = None + (_, conv_node) = node_map["conv"] + (_, bn_node) = node_map["bn"] + (_, conv_weight) = node_map["conv_weight"] + if "conv_bias" in node_map: + (_, conv_bias) = node_map["conv_bias"] + fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m) + + # Copy over literal args for conv + for original_node in _filter_nodes_map(r.nodes_map).values(): + if _is_conv_or_conv_transpose_node(original_node): + _copy_over_literal_conv_args(original_node, conv_node) + + m.graph.eliminate_dead_code() + m.recompile() + return m diff --git a/torchao/quantization/pt2e_flow/pt2e/representation/__init__.py b/torchao/quantization/pt2e_flow/pt2e/representation/__init__.py new file mode 100644 index 0000000000..9ddac64c04 --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/representation/__init__.py @@ -0,0 +1,5 @@ +from .rewrite import reference_representation_rewrite + +__all__ = [ + "reference_representation_rewrite", +] diff --git a/torchao/quantization/pt2e_flow/pt2e/representation/rewrite.py b/torchao/quantization/pt2e_flow/pt2e/representation/rewrite.py new file mode 100644 index 0000000000..e0fbb0a416 --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/representation/rewrite.py @@ -0,0 +1,819 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Optional + +import torch +from 
torch._higher_order_ops.out_dtype import out_dtype +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.fx import GraphModule +from torch.fx.subgraph_rewriter import replace_pattern + +from torchao.quantization.pt2e_flow.pt2e.export_utils import _WrapperModule +from torchao.quantization.pt2e_flow.pt2e.utils import ( + _get_aten_graph_module_for_pattern, + _replace_literals_with_existing_placeholders, + _replace_literals_with_new_placeholders, + remove_tensor_overload_for_qdq_ops, +) + +__all__ = [ + "reference_representation_rewrite", +] + + +def _qdq_quantized_linear( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + torch.int8, + ) + out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_linear( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. + # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max) + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.linear.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None, + ) + # TODO: change to mul.Scalar + # Note: we are quantizing bias with these scales without signal from user, but it might be OK + bias_scale = x_scale * weight_scale + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + acc_i32 = acc_i32 + bias_i32 + # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. 
Scalar values + acc_i32 = ( + out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + acc_i32, + x_scale * weight_scale / out_scale, + ) + + out_zero_point + ) + out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8) + return out_i8 + + +def _qdq_dynamic_quantized_linear( + x_fp32, + x_quant_min, + x_quant_max, + x_eps, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, +): + x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams( + x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8 + ) + x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + torch.int8, + ) + out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32) + return out_fp32 + + +def _reference_dynamic_quantized_linear( + x_fp32, + x_quant_min, + x_quant_max, + x_eps, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, +): + x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams( + x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8 + ) + # decomposed representation for quantize_per_tensor + # TODO: use out_dtype(mul, ...) here when the op is ready + x_fp32 = x_fp32 / x_scale # fp32 + # round modes might be different here + # pytorch is rounding to even, which is also common for most of the backends + x_fp32 = torch.round(x_fp32) # fp32 + x_i32 = x_fp32.to(dtype=torch.int32) # int32 + x_i32 = x_i32 + x_zero_point # int32 + # clamp works for fp32, int32 and int8 dtypes + x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max) # int32 + x_i8 = x_i32.to(dtype=torch.int8) + + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.linear.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None, + ) + bias_scale = x_scale * weight_scale + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + acc_i32 = acc_i32 + bias_i32 + out_fp32 = acc_i32 * (x_scale * weight_scale) + return out_fp32 + + +def _qdq_quantized_conv2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + stride = [1, 1] + padding = [0, 0] + dilation = [1, 1] + transposed = False + output_padding = [0, 0] + groups = 1 + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + torch.int8, + ) + out_fp32 = torch.ops.aten.convolution.default( + x_fp32, + weight_fp32, + bias_fp32, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + out_i8 = 
torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_conv2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + weight_i8, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + bias_fp32, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + stride = [1, 1] + padding = [0, 0] + dilation = [1, 1] + transposed = False + output_padding = [0, 0] + groups = 1 + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. + # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max) + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.convolution.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + # Note: we are quantizing bias with these scales without signal from user, but it might be OK + bias_scale = x_scale * weight_scale + # bias quantization to int32 uses bias_scale = x_scale * weight_scale due to: + # Take linear calculation for example + # Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(i, k)_fp32] + bias_(i)_fp32 + # Represent X, W fp32 as their dequant transforms + # A_fp32 = (A_q - A_zero_point)/A_scale + # Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_fp32 - X_zp) * X_scale * (W_(i, k)_fp32 - W_zp) * W_scale] + bias_(i)_fp32 + # Factor out X_scale and W_scale + # Out_(i, j)_fp32 = ((X_scale * W_scale) * Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)]) + bias_(i)_fp32 + # In order to addition of bias_(i)_fp32 inside, we must do + # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)] + (1 / (X_scale * W_scale)) * bias_(i)_fp32)W_scale # noqa: B950 + # Note we had to multiply bias_fp32 qith X_scale * W_scale = bias_scale + # Thus bias quantization to int32 must be with X_scale * W_scale + + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + # Unsqueeze to match broadcast dims + # Unfortnuately I cannot do bias_i32.unsqueeze(0) due to literal matching nightmare + # in graph pattern replacement + bias_i32 = bias_i32.unsqueeze(-1) + bias_i32 = bias_i32.unsqueeze(-1) + acc_i32 = acc_i32 + bias_i32 + # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. 
Scalar values + acc_i32 = ( + out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + acc_i32, + x_scale * weight_scale / out_scale, + ) + + out_zero_point + ) + out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8) + return out_i8 + + +def _qdq_quantized_add_relu( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8 + ) + y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8 + ) + out_fp32 = x_fp32 + y_fp32 + out_fp32 = torch.ops.aten.relu(out_fp32) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_add_relu( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + """ + See comments for `_reference_quantized_add` for more information on + how to derive the formula for out_i8 based on x_i8 and y_i8 + """ + x_i32 = x_i8.to(torch.int32) + y_i32 = y_i8.to(torch.int32) + # TODO: change this to mul.Scalar? + x_i32 = out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + (x_i32 - x_zero_point), + (x_scale / out_scale), + ) + y_i32 = out_dtype( + torch.ops.aten.mul.Tensor, + torch.int32, + (y_i32 - y_zero_point), + (y_scale / out_scale), + ) + out_i32 = x_i32 + y_i32 + out_zero_point + # out_i32 = torch.ops.aten.clamp(out_i32, out_zero_point) + out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8) + return out_i8 + + +def _qdq_quantized_add( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8 + ) + y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8 + ) + out_fp32 = x_fp32 + y_fp32 + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_add( + x_i8, + x_scale, + x_zero_point, + y_i8, + y_scale, + y_zero_point, + out_scale, + out_zero_point, + quant_min, + quant_max, +): + """ + # How to Derive the formula for out_i8 based on x_i8 and y_i8 + # (since quantized add takes x_i8, y_i8 and their quantization parameters, and produce an out_i8) + + # out_i8 is quantized output, we can write down the formula for it first: + out_i8 = out_f32 / out_scale + out_zero_point (1) + + # then out_fp32 is computed from x_f32 + y_f32, and the x_fp32 and y_fp32 are the dequantized x_i8 and y_i8 + out_f32 = x_f32 + y_f32 (2) + x_fp32 = (x_i8 - x_zero_point) * x_scale (3) + y_fp32 = (y_i8 - y_zero_point) * y_scale (4) + + # applying the above fomula to the out_i8 equation we can get the following: + out_i8 = out_fp32 / out_scale + out_zero_point # (1) + = (x_f32 + y_f32) / out_scale + out_zero_point # applying (2) to substitute out_fp32 with x_fp32 + y_fp32 + = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point # apply (3) and (4) + """ + x_i32 = x_i8.to(torch.int32) + y_i32 = y_i8.to(torch.int32) + # TODO: use out_dtype op + 
x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32) + y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32) + out_i32 = x_i32 + y_i32 + out_zero_point + quant_min = -128 + quant_max = 127 + out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8) + return out_i8 + + +def _qdq_quantized_max_pool2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + kernel_size = 1 + stride = 1 + padding = 0 + dilation = 1 + ceil_mode = False + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8 + ) + out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default( + x_fp32, kernel_size, stride, padding, dilation, ceil_mode + ) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantized_max_pool2d( + x_i8, + x_scale, + x_zero_point, + x_quant_min, + x_quant_max, + out_scale, + out_zero_point, + out_quant_min, + out_quant_max, +): + kernel_size = 1 + stride = 1 + padding = 0 + dilation = 1 + ceil_mode = False + # to preserve x_quant_min, x_quant_max in the graph for pattern matching + x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max) + x_i32 = x_i8.to(torch.int32) + out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default( + x_i32 - x_zero_point, kernel_size, stride, padding, dilation, ceil_mode + ) + out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point + out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max) + out_i8 = out_fp32.to(torch.int8) + return out_i8 + + +def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max): + x = torch.ops.quantized_decomposed.quantize_per_tensor( + x_fp32, scale, zero_point, quant_min, quant_max, torch.int8 + ) + return x + + +def _reference_quantize_per_tensor_int8( + x_fp32, scale, zero_point, quant_min, quant_max +): + # TODO: use out_dtype(mul, ...) here when the op is ready + x = x_fp32 / scale # fp32 + # round modes might be different here + # pytorch is rounding to even, which is also common for most of the backends + x = torch.round(x) # fp32 + x = x.to(dtype=torch.int32) # int32 + x = x + zero_point # int32 + # clamp works for fp32, int32 and int8 dtypes + x = torch.clamp(x, quant_min, quant_max) # int32 + x = x.to(dtype=torch.int8) + return x + + +def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, scale, zero_point, quant_min, quant_max, torch.int8 + ) + return x_fp32 + + +def _reference_dequantize_per_tensor_int8( + x_i8, scale, zero_point, quant_min, quant_max +): + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. 
+ # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max) + # TODO: use out_dtype op + # note: x_i8.to(torch.int32) does not work here + # TODO: debug the implementation later when torchdynamo time out issue is resolved + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + + +def _quantize_per_channel_int8( + x_fp32, scales, zero_points, ch_axis, quant_min, quant_max +): + out_i8 = torch.ops.quantized_decomposed.quantize_per_channel( + x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8 + ) + return out_i8 + + +def _reference_quantize_per_channel_int8( + x_fp32, scales, zero_points, ch_axis, quant_min, quant_max +): + x_fp32 = torch.transpose(x_fp32, ch_axis, -1) + out_i32 = torch.ops.aten.clamp( + torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max + ) + out_i32 = torch.transpose(out_i32, ch_axis, -1) + return out_i32.to(torch.int8) + + +def _dequantize_per_channel_int8( + x_i8, scales, zero_points, ch_axis, quant_min, quant_max +): + # the following will be replaced as placeholders + out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel( + x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8 + ) + return out_fp32 + + +def _reference_dequantize_per_channel_int8( + x_i8, scales, zero_points, ch_axis, quant_min, quant_max +): + # the following will be replaced as placeholders + # in order to preserve the quant_min/quant_max args for pattern matching (e.g. matching for int4 quantized ops) + # we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max) + x_i8 = torch.transpose(x_i8, ch_axis, -1) + x_i32 = x_i8.to(torch.int32) + out_fp32 = (x_i32 - zero_points).to(torch.float) * scales + out_fp32 = torch.transpose(out_fp32, ch_axis, -1) + return out_fp32 + + +def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule): + return _replace_literals_with_existing_placeholders( + gm, exclude_literals=[-1], literal_to_ph_idx={1: 3, -128: 4, 127: 5} + ) + + +@dataclass +class _RewriteInfo: + """Data needed for rewrite, this includes example inputs, pattern and replacement functions + and post transformation functions for the exported pattern and replacement GraphModule + """ + + # example inputs used for exporting the pattern into GraphModule + example_inputs: tuple[Any, ...] 
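+ # pattern and replacement are callables (wrapped in _WrapperModule) that get
+ # exported into aten GraphModules via `_get_aten_graph_module_for_pattern`
+ # before being handed to `replace_pattern`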
+ pattern: Callable + replacement: Callable + # post transformation on the exported pattern and replacement GraphModule + pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None + replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None + + +def reference_representation_rewrite(model: GraphModule) -> GraphModule: + _QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (2, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randint(-128, 127, (5, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( + torch.randn((2, 5), dtype=torch.float), + -128, + 127, + torch.finfo(torch.float32).eps, + torch.randint(-128, 127, (5, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + ) + + _QUANTIZED_CONV2d_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( + torch.randn(1, 3, 3, 3, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + ) + + _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( + torch.randn(1, 3, 3, 3, dtype=torch.float), + 
torch.randn(3, dtype=torch.float), + torch.zeros(3, dtype=torch.int), + 1, + -128, + 127, + ) + + _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(3, dtype=torch.float), + torch.zeros(3, dtype=torch.int), + 1, + -128, + 127, + ) + + _REWRITE_INFO_LIST = [ + _RewriteInfo( + _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS, + _WrapperModule(_qdq_dynamic_quantized_linear), + _WrapperModule(_reference_dynamic_quantized_linear), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3}, + ), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3}, + ), + ), + _RewriteInfo( + _QUANTIZED_LINEAR_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_linear), + _WrapperModule(_reference_quantized_linear), + _replace_literals_with_new_placeholders, + _replace_literals_with_new_placeholders, + ), + _RewriteInfo( + _QUANTIZED_CONV2d_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_conv2d), + _WrapperModule(_reference_quantized_conv2d), + partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), + partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), + ), + _RewriteInfo( + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_add_relu), + _WrapperModule(_reference_quantized_add_relu), + ), + _RewriteInfo( + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_add), + _WrapperModule(_reference_quantized_add), + ), + _RewriteInfo( + _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_max_pool2d), + _WrapperModule(_reference_quantized_max_pool2d), + _replace_literals_with_new_placeholders, + _replace_literals_with_new_placeholders, + ), + _RewriteInfo( + _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, + _WrapperModule(_quantize_per_tensor_int8), + _WrapperModule(_reference_quantize_per_tensor_int8), + ), + _RewriteInfo( + _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, + _WrapperModule(_dequantize_per_tensor_int8), + _WrapperModule(_reference_dequantize_per_tensor_int8), + ), + _RewriteInfo( + _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, + _WrapperModule(_quantize_per_channel_int8), + _WrapperModule(_reference_quantize_per_channel_int8), + _replace_ph_qdq_per_channel_replacement, + _replace_ph_qdq_per_channel_replacement, + ), + _RewriteInfo( + _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, + _WrapperModule(_dequantize_per_channel_int8), + _WrapperModule(_reference_dequantize_per_channel_int8), + _replace_ph_qdq_per_channel_replacement, + _replace_ph_qdq_per_channel_replacement, + ), + ] + + remove_tensor_overload_for_qdq_ops(model) + + for rewrite_info in _REWRITE_INFO_LIST: + example_inputs = rewrite_info.example_inputs + pattern = rewrite_info.pattern + replacement = rewrite_info.replacement + pattern_post_trans = rewrite_info.pattern_post_trans + replacement_post_trans = rewrite_info.replacement_post_trans + pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type] + replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type] + if pattern_post_trans: + pattern = pattern_post_trans(pattern) + if replacement_post_trans: + replacement = replacement_post_trans(replacement) + 
pattern.recompile() # type: ignore[attr-defined] + replacement.recompile() # type: ignore[attr-defined] + replace_pattern(model, pattern, replacement) + + return model diff --git a/torchao/quantization/pt2e_flow/pt2e/utils.py b/torchao/quantization/pt2e_flow/pt2e/utils.py new file mode 100644 index 0000000000..e92195b591 --- /dev/null +++ b/torchao/quantization/pt2e_flow/pt2e/utils.py @@ -0,0 +1,610 @@ +# mypy: allow-untyped-defs +import operator +import types +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F + +# Makes sure that quantized_decomposed ops are registered +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.export.unflatten import _assign_attr, _AttrKind +from torch.fx import GraphModule, Node +from torch.nn.utils.fusion import fuse_conv_bn_weights +from torch.utils._pytree import LeafSpec + +import torchao.quantization.pt2e_flow.pt2e._affine_quantization # noqa: F401 +from torchao.quantization.pt2e_flow.quantizer import QuantizationAnnotation + +__all__ = [ + "fold_bn_weights_into_conv_node", + "remove_tensor_overload_for_qdq_ops", +] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool: + """ + Assuming dest is one of the ops inserted by quant workflow, this function + finds if source and dest are connected. 
Assumption is that only quant workflow + inserted ops exist between source and dest + """ + quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS + quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor) + while dest.target in quant_workflow_ops: + if not isinstance(dest.args[0], torch.fx.Node): + raise ValueError( + f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}" + ) + dest = dest.args[0] + return dest == source + + +def _find_q_dq_node_for_user( + produer: torch.fx.Node, user: torch.fx.Node +) -> tuple[Any, Any]: + """ + Find q, dq pair corresponding to [producer -> q -> dq -> user] + Utils works by finding dq arg of user and ensuring it is connected to + producer + """ + dq_node = None + for n in user.args: + if ( + isinstance(n, torch.fx.Node) + and n.op == "call_function" + and n.target in _DEQUANTIZE_OPS + ): + if _is_connected(produer, n): + dq_node = n + break + if dq_node is None: + for n in user.kwargs: + if ( + isinstance(n, torch.fx.Node) + and n.op == "call_function" + and n.target in _DEQUANTIZE_OPS + ): + if _is_connected(produer, n): + dq_node = n + break + if dq_node is None: + return (None, None) + + q_node = None + if ( + dq_node.args[0].op == "call_function" # type: ignore[union-attr] + and dq_node.args[0].target in _QUANTIZE_OPS # type: ignore[union-attr] + ): + q_node = dq_node.args[0] + return (q_node, dq_node) + + +def _is_sym_size_node(node: Node): + return ( + node.op == "call_function" + and node.target == torch.ops.aten.sym_size.default + or node.target == torch.ops.aten.sym_numel.default + or node.target == torch.ops.aten.sym_numel + or node.target == torch.ops.aten.sym_size + ) + + +def _filter_sym_size_users(node: torch.fx.Node) -> list[torch.fx.Node]: + node_users = list(filter((lambda x: (_is_sym_size_node(x) is False)), node.users)) + return node_users + + +def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool: + if annotation is None: + return False + input_qspec_map = annotation.input_qspec_map + output_qspec = annotation.output_qspec + if len(input_qspec_map) == 0 and output_qspec is None: + return False + return True + + +def _get_tensor_constant_from_node(node, m): + if node is None: + return None + assert node.op == "get_attr" + target_atoms = node.target.split(".") + attr_itr = m + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def _get_all_arguments(orig_args, orig_kwargs, args_schema): + all_args = [] + for i, schema in enumerate(args_schema): + if schema.name in orig_kwargs: + all_args.append(orig_kwargs[schema.name]) + elif not schema.kwarg_only and i < len(orig_args): + all_args.append(orig_args[i]) + else: + all_args.append(schema.default_value) + return all_args + + +def _is_supported_batch_norm_for_training(node: Node): + """ + Return True if the given node refers to an aten batch norm op QAT supports. 
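+    (The inference-only variant, `_native_batch_norm_legit_no_training`, is not
+    listed here; it is covered separately by `_is_bn_node` below.)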
+ """ + supported_ops = [ + torch.ops.aten.batch_norm.default, + torch.ops.aten._native_batch_norm_legit.default, + # Note: we won't need this op anymore after batch norm consolidation + # For now, we need to continue to support it because it gives better + # training numerics than `_native_batch_norm_legit` + torch.ops.aten.cudnn_batch_norm.default, + torch.ops.aten.miopen_batch_norm.default, + ] + return node.target in supported_ops + + +# TODO: move this to torch/ao/quantization/utils.py +def _is_conv_node(n: Node): + """ + Return whether the node refers to an aten conv op. + """ + return n.op == "call_function" and n.target in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ] + + +def _is_conv_transpose_node(n: Node): + """ + Return whether the node refers to an aten conv_transpose op. + """ + return n.op == "call_function" and n.target in [ + torch.ops.aten.conv_transpose1d, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d, + torch.ops.aten.conv_transpose2d.input, + ] + + +def _is_conv_or_conv_transpose_node(n: Node): + """ + Return whether the node refers to an aten conv or conv transpose op. + """ + return _is_conv_node(n) or _is_conv_transpose_node(n) + + +def _is_conv_transpose_fn(conv_fn: Callable): + return conv_fn in [F.conv_transpose1d, F.conv_transpose2d] + + +def _is_bn_node(n: Node): + return ( + _is_supported_batch_norm_for_training(n) + or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default + ) + + +def fold_bn_weights_into_conv_node( + conv_node: Node, + conv_weight_node: Node, + conv_bias_node: Optional[Node], + bn_node: Node, + m: GraphModule, +) -> None: + # conv args: input, weight, bias, stride, padding, dilation, ... + conv_w = _get_tensor_constant_from_node(conv_weight_node, m) + conv_b = _get_tensor_constant_from_node(conv_bias_node, m) + transpose = _is_conv_transpose_node(conv_node) + + # eval bn args: input, weight, bias, running mean, running var, momentum, eps + # train bn args: input, weight, bias, running mean, running var, training, momentum, eps + bn_args_schema = bn_node.target._schema.arguments # type: ignore[union-attr] + bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema) + bn_w = _get_tensor_constant_from_node(bn_args[1], m) + bn_b = _get_tensor_constant_from_node(bn_args[2], m) + bn_rm = _get_tensor_constant_from_node(bn_args[3], m) + bn_rv = _get_tensor_constant_from_node(bn_args[4], m) + if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default: + eps_arg_index = 6 + elif _is_supported_batch_norm_for_training(bn_node): + eps_arg_index = 7 + else: + raise ValueError("BN node target is unexpected ", bn_node.target) + bn_eps = bn_args[eps_arg_index] + + fused_weight, fused_bias = fuse_conv_bn_weights( + conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose + ) + + # update the weight and bias for conv + conv_args = list(conv_node.args) + # filling in the default bias argument + if len(conv_args) == 2: + conv_args.append(None) + + # calling data since the fused_weight and fused_bias are nn.Parameter + weight_attr_name = conv_weight_node.target + assert isinstance(weight_attr_name, str) + _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER) + if conv_bias_node is not None: + bias_attr_name = conv_bias_node.target + _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER) + else: + bias_attr_name = weight_attr_name + "_bias" + _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER) + 
with m.graph.inserting_before(conv_node): + get_bias_node = m.graph.get_attr(bias_attr_name) + # NOTE: here we assume the bias of conv is not quantized! + conv_args[2] = get_bias_node + conv_node.args = tuple(conv_args) + + # native_batch_norm has 3 outputs, we expect getitem calls on the output + # and we want to replace the uses of getitem 0 with the output of conv + # + if bn_node.target == torch.ops.aten.batch_norm.default: + # With the new training ir, instead of batch_norm + getitem, + # we only have the batch_norm node. + # + # Before: + # conv -> bn -> users + # After: + # conv -> users + # bn has no users now + bn_node.replace_all_uses_with(conv_node) + else: + # Before: + # conv -> bn - (first output) -> users1 + # \ - (second output) -> users2 + # \ - (third output) -> users3 + # After: + # conv -> (first output) -> users1 + # bn - + # \ - (second output) -> users2 + # \ - (third output) -> users3 + # if users2 and users3 are empty then bn will be removed through dead code elimination + for user in bn_node.users: + if ( + user.op != "call_function" + or user.target != operator.getitem + or user.args[1] != 0 + ): + continue + user.replace_all_uses_with(conv_node) + + # If the BN node does not have users, erase it from the graph + # Note: we need to do this manually because the model can still be in train + # mode at this point, in which case DCE won't erase the BN node automatically + # since the node refers to a mutating op. Here we still need to call DCE first + # to get rid of the unused getitem nodes that consume the BN node. + m.graph.eliminate_dead_code() + if len(bn_node.users) == 0: + m.graph.erase_node(bn_node) + + +# fuse conv bn weights, inplace modification of the graph_module and graph +def _fuse_conv_bn_(m: GraphModule) -> None: + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return + for n in m.graph.nodes: + if n.op != "call_function" or n.target not in ( + torch.ops.aten._native_batch_norm_legit_no_training.default, + torch.ops.aten.batch_norm.default, + ): + continue + bn_node = n + n = bn_node.args[0] + if not _is_conv_or_conv_transpose_node(n): + continue + conv_node = n + conv_weight_node = conv_node.args[1] + conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None + fold_bn_weights_into_conv_node( + conv_node, conv_weight_node, conv_bias_node, bn_node, m + ) + + m.graph.eliminate_dead_code() + m.recompile() + + +def _get_node_name_to_scope(model: GraphModule) -> dict[str, tuple[str, type]]: + # TODO: move this information to fx node itself + node_name_to_scope: dict[str, tuple[str, type]] = {} + for n in model.graph.nodes: + nn_module_stack = n.meta.get("nn_module_stack", None) + current_scope = ("", type(None)) + if nn_module_stack: + bt = list(nn_module_stack.values())[-1] + current_scope = (bt[0].split(".")[-1], bt[1]) + node_name_to_scope[n.name] = current_scope + return node_name_to_scope + + +def _get_aten_graph_module_for_pattern( + pattern: Callable, + example_inputs: tuple[Any, ...], + is_cuda: bool = False, + **kwargs, +) -> GraphModule: + """ + Convert the pattern to an FX graph with decomposed aten ops. 
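+
+    Example (a minimal, illustrative sketch; the module, names and shapes below
+    are placeholders rather than part of this API):
+
+        class Pattern(torch.nn.Module):  # toy pattern, for illustration only
+            def forward(self, x, scale):
+                return x * scale
+
+        example_inputs = (torch.randn(1, 3, 3, 3), torch.randn(1))
+        pattern_gm = _get_aten_graph_module_for_pattern(Pattern(), example_inputs)
+        # pattern_gm is a GraphModule made of decomposed aten ops, suitable for
+        # subgraph pattern matching (e.g. with replace_pattern).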
+ """ + if is_cuda: + example_inputs = tuple( + [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs] + ) + + aten_pattern = torch.export.export_for_training( + pattern, # type: ignore[arg-type] + example_inputs, + kwargs, + ).module() + + aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] + aten_pattern.recompile() # type: ignore[operator] + + # ep.module() adds copy_ nodes for the mutated inputs. + # For patterns, it doesn't matter + for node in aten_pattern.graph.nodes: # type: ignore[union-attr] + if ( + node.op == "call_function" + and node.target == torch.ops.aten.copy_.default + and len(node.users) == 0 + ): + aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr] + + aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] + aten_pattern.recompile() # type: ignore[operator] + + return aten_pattern # type: ignore[return-value] + + +def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None: + """Remove .tensor overload for quantize/dequantize ops so that we can + use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e + """ + _MAP = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel, + torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel, + torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp, + } + for n in match_pattern.graph.nodes: + if n.op != "call_function": + continue + if n.target in _MAP: + n.target = _MAP[n.target] + + +def _is_literal(arg): + if isinstance(arg, (int, float)): + return True + if isinstance(arg, (tuple, list)): + return all(map(_is_literal, arg)) + return False + + +def _replace_literals_with_new_placeholders( + gm: torch.fx.GraphModule, + merge_dup: bool = False, + exclude_literals: Optional[list[Any]] = None, +): + """Replace the literals in the graph with placeholder nodes that's created on the fly while we + traverse the graph, so that the literal arguments in the graph can be matched and replaced + + To use this, the pattern and replacement graph should have the exact same number of literal args + and they should be used in the exact same order in the pattern and replacement graph. + + If the literal arguments are not used in the same order in pattern and replacement graph, please + use `_replace_literals_with_existing_placeholders` instead + + Args: + `gm`: input GraphModule that we'll transform + `merge_dup`: boolean flag to indicate that if the same literal appears multiple times in + the graph, whether they should correspond to the same placeholder or not + `exclude_literals`: a list of literals that will not be replaced with placeholders + + Example: + + # 1. 
Original Graph + def pattern(self, x): + return x + 3 + + def replacement(self, x): + return x - 3 + + example_inputs = (torch.randn(1, 3, 3, 3),) + pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs) + replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus) + + # 2. Before calling replace literals we'll see the following graph: + def pattern(self, x): + return x + 3 + + def replacement(self, x): + return x - 3 + + pattern_gm = _replace_literals_with_new_placeholders(pattern_gm) + replacement_gm = _replace_literals_with_new_placeholders(replacement_gm) + + # 3. After replacing literals with new placeholder nodes + + def pattern(self, x, new_ph): + return x + new_ph + + def pattern(self, x, new_ph): + return x - new_ph + + """ + last_ph = None + cnt = 0 + literal_to_ph: dict[Union[float, bool, int, torch.dtype], Node] = {} + if exclude_literals is None: + exclude_literals = [] + + in_spec = gm._in_spec + args_spec = in_spec.children_specs[0] + for node in gm.graph.nodes: + if node.op == "placeholder": + last_ph = node + cnt += 1 + continue + with gm.graph.inserting_after(last_ph): + new_args = [] + for arg in node.args: + if _is_literal(arg) and arg not in exclude_literals: + if merge_dup and arg in literal_to_ph: + new_args.append(literal_to_ph[arg]) + else: + ph_node = gm.graph.placeholder("arg" + str(cnt)) + new_args.append(ph_node) + args_spec.children_specs.append(LeafSpec()) + cnt += 1 + if merge_dup: + literal_to_ph[arg] = ph_node + else: + new_args.append(arg) + new_args = tuple(new_args) + + node.args = new_args + + # Update `num_nodes`, `num_leaves`, `num_children`. + args_spec.__post_init__() + in_spec.__post_init__() + return gm + + +def _replace_literals_with_existing_placeholders( + gm: torch.fx.GraphModule, + exclude_literals: Optional[list[Any]] = None, + literal_to_ph_idx: Optional[dict[Union[float, int, bool, torch.dtype], int]] = None, +): + """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments + in the graph can be matched and replaced + + To use this, all literal args in the graph should be unique and each of them should correspond + to exactly one placeholder node + + # 1. Original Graph + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + x_i8 = torch.clamp(x_i8, quant_min, quant_max) + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + + example_inputs = ( + torch.randn(1, 3, 3, 3), + 1.0, + 0, + -128, + 127, + ) + pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs) + replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus) + + # 2. 
Before calling replace literals we'll see the following graph: + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + x_i8 = torch.clamp(x_i8, -128, 127) + return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32) + + # Note that literal args appear in different order in pattern and replacement graph, so + # we can't use _replace_literals_with_new_placeholders + + literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4} + pattern_gm = _replace_literals_with_existing_placeholders(pattern_gm, literal_to_ph_idx) + replacement_gm = _replace_literals_with_existing_placeholders(replacement_gm, literal_to_ph_idx) + + # 3. After replacing literals with existing placeholder nodes + + def pattern(self, x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max) + + def replacement(x_i8, scale, zero_point, quant_min, quant_max): + # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values + x_i8 = torch.clamp(x_i8, quant_min, quant_max) + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + """ + if exclude_literals is None: + exclude_literals = [] + + if literal_to_ph_idx is None: + literal_to_ph_idx = {} + + phs = [node for node in gm.graph.nodes if node.op == "placeholder"] + + for node in gm.graph.nodes: + if node.op != "call_function": + continue + new_args = [] + for arg in node.args: + if ( + _is_literal(arg) + and arg not in exclude_literals + and arg in literal_to_ph_idx + ): + ph_idx = literal_to_ph_idx[arg] + ph_node = phs[ph_idx] + new_args.append(ph_node) + else: + new_args.append(arg) + new_args = tuple(new_args) + node.args = new_args + return gm + + +# TODO: Handle this in export itself and don't wrap the model in another GraphModule +# in prepare and convert +def _disallow_eval_train(model: GraphModule): + """ + Disallow calling `model.train()` or `model.eval()` on the given GraphModule. + This is useful for exported models, where these methods don't actually behave as expected. + """ + error_message = """ + Calling train() or eval() is not supported for exported models. + Please call `torchao.quantization.pt2e_flow.move_exported_model_to_train(model)` (or eval) instead. + + If you cannot replace the calls to `model.train()` and `model.eval()`, you may override + the behavior for these methods by calling `torchao.quantization.pt2e_flow.allow_exported_model_train_eval(model)`, + which does the above automatically for you. Note that this has limited effect on switching + behavior between train and eval modes, and should be used only for special ops such as dropout + and batchnorm. 
+ """ + + def _train(self, mode: bool = True): + raise NotImplementedError(error_message) + + def _eval(self, mode: bool = True): + raise NotImplementedError(error_message) + + model.train = types.MethodType(_train, model) # type: ignore[method-assign] + model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] + return model diff --git a/torchao/quantization/pt2e_flow/qconfig.py b/torchao/quantization/pt2e_flow/qconfig.py new file mode 100644 index 0000000000..c785ff2b1a --- /dev/null +++ b/torchao/quantization/pt2e_flow/qconfig.py @@ -0,0 +1,699 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from collections import namedtuple +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FakeQuantizeBase, + FusedMovingAvgObsFakeQuantize, + default_dynamic_fake_quant, + default_embedding_fake_quant, + default_embedding_fake_quant_4bit, + default_fake_quant, + default_fused_act_fake_quant, + default_fused_per_channel_wt_fake_quant, + default_fused_wt_fake_quant, + default_per_channel_weight_fake_quant, + default_weight_fake_quant, + fused_per_channel_wt_fake_quant_range_neg_127_to_127, + fused_wt_fake_quant_range_neg_127_to_127, +) +from typing_extensions import deprecated + +from .observer import ( + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + NoopObserver, + ObserverBase, + PlaceholderObserver, + ReuseInputObserver, + _PartialWrapper, + default_debug_observer, + default_dynamic_quant_observer, + default_float_qparams_observer, + default_float_qparams_observer_4bit, + default_observer, + default_per_channel_weight_observer, + default_placeholder_observer, + default_reuse_input_observer, + default_weight_observer, + per_channel_weight_observer_range_neg_127_to_127, + weight_observer_range_neg_127_to_127, +) + +__all__ = [ + "QConfig", + # TODO: deprecated, remove + "QConfigDynamic", + "default_qconfig", + "default_debug_qconfig", + "default_per_channel_qconfig", + "default_dynamic_qconfig", + "float16_dynamic_qconfig", + "float16_static_qconfig", + "per_channel_dynamic_qconfig", + "float_qparams_weight_only_qconfig", + "float_qparams_weight_only_qconfig_4bit", + "default_quint8_weight_qconfig", + "default_qat_qconfig", + "default_dynamic_qat_qconfig", + "default_weight_only_qconfig", + "default_activation_only_qconfig", + "default_qat_qconfig_v2", + "default_reuse_input_qconfig", + "default_symmetric_qnnpack_qconfig", + "default_per_channel_symmetric_qnnpack_qconfig", + "default_symmetric_qnnpack_qat_qconfig", + "default_per_channel_symmetric_qnnpack_qat_qconfig", + "default_embedding_qat_qconfig", + "default_embedding_qat_qconfig_4bit", + "get_default_qconfig", + "get_default_qat_qconfig", + "get_default_qconfig_dict", + "get_default_qat_qconfig_dict", + "QConfigAny", + "qconfig_equals", +] + + +class QConfig(namedtuple("QConfig", ["activation", "weight"])): + """ + Describes how to quantize a layer or a part of the network by providing + settings (observer classes) for activations and weights respectively. + + + Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns + instances on invocation, not the concrete observer instances themselves. + Quantization preparation function will instantiate observers multiple times for each of the layers. 
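+
+    For example (illustrative only)::
+
+        # correct: pass the observer class itself (or a factory built via with_args)
+        qconfig = QConfig(activation=MinMaxObserver, weight=MinMaxObserver)
+
+        # incorrect: passing observer *instances* raises a ValueError
+        # qconfig = QConfig(activation=MinMaxObserver(), weight=MinMaxObserver())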
+ + + Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` + method (that behaves like functools.partial):: + + my_qconfig = QConfig( + activation=MinMaxObserver.with_args(dtype=torch.qint8), + weight=default_observer.with_args(dtype=torch.qint8)) + + """ + + __slots__ = () + + def __new__(cls, activation, weight): + # catch common mistakes + if isinstance(activation, nn.Module) or isinstance(weight, nn.Module): + raise ValueError( + "QConfig received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" + ) + return super().__new__(cls, activation, weight) + + +@deprecated( + "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead", + category=FutureWarning, +) +class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])): + """ + Describes how to dynamically quantize a layer or a part of the network by providing + settings (observer classes) for weights. + + It's like QConfig, but for dynamic quantization. + + Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns + instances on invocation, not the concrete observer instances themselves. + Quantization function will instantiate observers multiple times for each of the layers. + + Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` + method (that behaves like functools.partial):: + + my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8)) + """ + + __slots__ = () + + def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): + # catch common mistakes + if isinstance(weight, nn.Module): + raise ValueError( + "QConfigDynamic received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" + ) + return super().__new__(cls, activation, weight) + + +default_qconfig = QConfig(activation=default_observer, weight=default_weight_observer) +""" +Default qconfig configuration. +""" + +default_debug_qconfig = QConfig( + weight=default_weight_observer, activation=default_debug_observer +) +""" +Default qconfig configuration for debugging. +""" + +default_per_channel_qconfig = QConfig( + activation=default_observer, weight=default_per_channel_weight_observer +) +""" +Default qconfig configuration for per channel weight quantization. +""" + +default_dynamic_qconfig = QConfig( + activation=default_dynamic_quant_observer, weight=default_weight_observer +) +""" +Default dynamic qconfig. +""" + +float16_dynamic_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True), + weight=PlaceholderObserver.with_args(dtype=torch.float16), +) +""" +Dynamic qconfig with weights quantized to `torch.float16`. +""" + +float16_static_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float16), + weight=PlaceholderObserver.with_args(dtype=torch.float16), +) +""" +Dynamic qconfig with both activations and weights quantized to `torch.float16`. +""" + +per_channel_dynamic_qconfig = QConfig( + activation=default_dynamic_quant_observer, + weight=default_per_channel_weight_observer, +) +""" +Dynamic qconfig with weights quantized per channel. 
+""" + +float_qparams_weight_only_qconfig = QConfig( + activation=default_placeholder_observer, weight=default_float_qparams_observer +) +""" +Dynamic qconfig with weights quantized with a floating point zero_point. +""" + +float_qparams_weight_only_qconfig_4bit = QConfig( + activation=default_placeholder_observer, weight=default_float_qparams_observer_4bit +) + +default_qat_qconfig = QConfig( + activation=default_fake_quant, weight=default_weight_fake_quant +) +""" +Default qconfig for QAT. +""" + +default_dynamic_qat_qconfig = QConfig( + activation=default_dynamic_fake_quant, weight=default_weight_fake_quant +) +""" +Default qconfig for dynamic QAT. +""" + +default_weight_only_qconfig = QConfig( + activation=torch.nn.Identity, weight=default_weight_fake_quant +) +""" +Default qconfig for quantizing weights only. +""" + +default_activation_only_qconfig = QConfig( + activation=default_fake_quant, weight=torch.nn.Identity +) +""" +Default qconfig for quantizing activations only. +""" + +# QAT config that uses a fused observer + fake quant modules for optimized training performance. +# to modify the activation/weight observers, the default entries in fake_quantize.py can be modified. +default_qat_qconfig_v2 = QConfig( + activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant +) +""" +Fused version of `default_qat_config`, has performance benefits. +""" + +default_reuse_input_qconfig = QConfig( + activation=default_reuse_input_observer, weight=NoopObserver +) +""" +Default qconfig for operators that reuse the observers from input Tensor, e.g. reshape +""" + + +def get_default_qconfig(backend="x86", version=0): + """ + Returns the default PTQ qconfig for the specified backend. + + Args: + * `backend` (str): a string representing the target backend. Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. + + Return: + qconfig + """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + + str(backend) + + f" not supported. backend must be one of {supported_backends}" + ) + + if version == 0: + if backend == "fbgemm": + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer, + ) + elif backend == "qnnpack": + # TODO: make this compatible with xnnpack constraints + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=False), + weight=default_weight_observer, + ) + elif backend == "onednn": + if not torch.cpu._is_vnni_supported(): + warnings.warn( + "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues " + "on CPU without Vector Neural Network Instruction support." + ) + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=False), + weight=default_per_channel_weight_observer, + ) + elif backend == "x86": + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer, + ) + else: + # won't reach + qconfig = default_qconfig + else: + raise AssertionError( + "Version number: " + + str(version) + + " in get_default_qconfig is not supported. Version number must be 0" + ) + + return qconfig + + +""" +Default, symmetric PTQ qconfig for the specified backend. And a per_channel +variant of the same. + +Symmetric here applies to signed weights with zero point = 0, and additional +value restrictions. The activations are also signed 8-bit integers with this +qconfig. 
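+
+For example (an illustrative bound check; the requantization formula is spelled
+out in the notes below): with input_scale and kernel_scale both bounded below
+by the observer eps of 2**-12, and output_scale assumed to be at most 256,
+requantization_scale >= (2**-12 * 2**-12) / 256 = 2**-32, which is exactly the
+0x1p-32 lower threshold required by xnnpack.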
+ + * Once this change is merged [as of 3/17/22], with backend or qengine = + 'qnnpack', some quantized operators with this symmetric qconfig may use + operators from xnnpack library. + + ** Support to use xnnpack ops with `qnnpack` backed for asymmetric + qconfig (returned by get_default_qconfig()) is not available yet. + + * This qconfig uses signed activations and weights. Weights have added + restrictions such as zero point is forced to be 0, making the weights + symmetric, hence the name. And the 8-bit quantized values are + restricting to to [-127, +127], excluding -128. + + * xnnpack has a requantization scale value restriction, 0x1p-32 <= + requantization_scale < 256.0 where, `requantization_scale = (input_scale + * kernel_scale) / (output_scale)`. Using this eps (w/ assumed max value + of 256) is to prevent requantization_scale to go below xnnpack lower + threshold. +""" +default_symmetric_qnnpack_qconfig = QConfig( + activation=HistogramObserver.with_args( + dtype=torch.qint8, reduce_range=False, eps=2**-12 + ), + weight=weight_observer_range_neg_127_to_127, +) + +default_per_channel_symmetric_qnnpack_qconfig = QConfig( + activation=HistogramObserver.with_args( + dtype=torch.qint8, reduce_range=False, eps=2**-12 + ), + weight=per_channel_weight_observer_range_neg_127_to_127, +) + +default_embedding_qat_qconfig = QConfig( + activation=NoopObserver.with_args(dtype=torch.float32), + weight=default_embedding_fake_quant, +) + +default_embedding_qat_qconfig_4bit = QConfig( + activation=NoopObserver.with_args(dtype=torch.float32), + weight=default_embedding_fake_quant_4bit, +) + +default_quint8_weight_qconfig = QConfig( + activation=HistogramObserver, weight=MinMaxObserver +) + + +def get_default_qat_qconfig(backend="x86", version=1): + """ + Returns the default QAT qconfig for the specified backend. + + Args: + * `backend` (str): a string representing the target backend. Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. + * `version`: version, for backwards compatibility. Can be `None` or `1`. + + Return: + qconfig + """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + + str(backend) + + f" not supported. backend must be one of {supported_backends}" + ) + + # Histogram observer is too slow for quantization aware training + if version == 0: + if backend == "fbgemm": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_per_channel_weight_fake_quant, + ) + elif backend == "qnnpack": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=False, + ), + weight=default_weight_fake_quant, + ) + elif backend == "onednn": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255 + ), + weight=default_per_channel_weight_fake_quant, + ) + elif backend == "x86": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_per_channel_weight_fake_quant, + ) + else: + qconfig = default_qat_qconfig + # Use the fused observe + fake_quant modules for doing QAT. 
+ elif version == 1: + if backend == "fbgemm": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + elif backend == "qnnpack": + # TODO: make this compatible with xnnpack constraints + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=False, + ), + weight=default_fused_wt_fake_quant, + ) + elif backend == "onednn": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255 + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + elif backend == "x86": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + else: + qconfig = default_qat_qconfig_v2 + else: + raise AssertionError( + "Version number: " + + str(version) + + "in get_default_qat_qconfig is not supported. Version number must be 0 or 1" + ) + + return qconfig + + +""" +Default symmetric QAT qconfig for qnnpack. And its per channel weight variant. +""" +default_symmetric_qnnpack_qat_qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + reduce_range=False, + eps=2**-12, + ), + weight=fused_wt_fake_quant_range_neg_127_to_127, +) + +default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + reduce_range=False, + eps=2**-12, + ), + weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127, +) + +_default_fp32_placeholder_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float32), + weight=PlaceholderObserver.with_args(dtype=torch.float32), +) + +_default_quint8_placeholder_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.quint8), + # operators using this qconfig doesn't have weights + weight=None, +) + + +@deprecated( + "`torch.ao.quantization.get_default_qconfig_dict` is deprecated and will be removed in " + "a future version. Please use `torch.ao.quantization.get_default_qconfig_mapping` instead.", + category=FutureWarning, +) +def get_default_qconfig_dict(backend="x86", version=0): + return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict() + + +@deprecated( + "`torch.ao.quantization.get_default_qat_qconfig_dict` is deprecated and will be removed in " + "a future version. Please use `torch.ao.quantization.get_default_qat_qconfig_mapping` instead.", + category=FutureWarning, +) +def get_default_qat_qconfig_dict(backend="x86", version=1): + return torch.ao.quantization.get_default_qat_qconfig_mapping( + backend, version + ).to_dict() + + +def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> None: + """ + Verifies that this `qconfig` is valid. 
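+    (Currently the only check performed is that a ConvTranspose module is not
+    configured with a per-channel weight observer, which is not supported yet.)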
+ """ + if qconfig is None: + return + is_conv_transpose_mod = isinstance( + mod, + (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d), + ) + if is_conv_transpose_mod: + if qconfig.weight is None: + # for now, we assume that any qconfig for ConvTranspose without a weight is valid + return + example_observer = qconfig.weight() + is_per_channel = isinstance( + example_observer, + ( + torch.ao.quantization.PerChannelMinMaxObserver, + torch.ao.quantization.MovingAveragePerChannelMinMaxObserver, + ), + ) + assert ( + not is_per_channel + ), "Per channel weight observer is not supported yet for ConvTranspose{n}d." + + +QConfigAny = Optional[QConfig] +QConfigAny.__module__ = "torch.ao.quantization.qconfig" + + +def _add_module_to_qconfig_obs_ctr( + qconfig: QConfigAny, module: Optional[nn.Module] +) -> Any: + r"""This is a helper function for use in quantization prepare that updates a qconfig so that + the constructors stored in the qconfig will create observers on the same device that + 'module' is on. This is intended to be used when the qconfigs are propagated to each + module in order to avoid potential device alignment issues. + + Args: + qconfig: QConfig with obs constructors stored in activation and weight + module: module which the qconfig is related to + + Return: + qconfig: configured so that obs constructors set to construct on the same device as module + """ + + if module is None or qconfig is None or qconfig._fields != ("activation", "weight"): + return qconfig + + def get_factory_kwargs_based_on_module_device(): + assert isinstance(module, torch.nn.Module) + devices = {p.device for p in module.parameters()} | { + p.device for p in module.buffers() + } + device = next(iter(devices)) if len(devices) > 0 else None + return None if device is None else {"device": device} + + def configure_constructor_to_put_obs_on_module_device(original_constructor): + try: + # check if constructor can accept factory_kwargs + check = original_constructor.with_args(factory_kwargs=None) + check() + return original_constructor.with_callable_args( + factory_kwargs=get_factory_kwargs_based_on_module_device + ) + except AttributeError: # qconfig doesn't have activation or weight + return original_constructor + except TypeError: # the class doesn't accept factory_kwargs argument + return original_constructor + + activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation) + weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight) + + return QConfig(activation, weight) + + +_ObserverOrFakeQuantizeConstructor = Union[ + _PartialWrapper, type[ObserverBase], type[FakeQuantizeBase] +] + + +def _obs_or_fq_ctr_equals( + obs_or_fq1: _ObserverOrFakeQuantizeConstructor, + obs_or_fq2: _ObserverOrFakeQuantizeConstructor, +): + if isinstance(obs_or_fq1, _PartialWrapper) and isinstance( + obs_or_fq2, _PartialWrapper + ): + return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2) + return obs_or_fq1 == obs_or_fq2 + + +def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper): + """ + Return whether the two partial wrappers are equal, + """ + # functools.partial has no __eq__ operator defined so '==' defaults to 'is' + obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords) + obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords) + keywords_equal = True + # compare observer constructor with _obs_or_fq_ctr_equals since direct compare would fail + if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords: + 
keywords_equal = keywords_equal and _obs_or_fq_ctr_equals( + obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"] + ) + obs_or_fq1_keywords.pop("observer") + obs_or_fq2_keywords.pop("observer") + keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords + return ( + obs_or_fq1.p.func == obs_or_fq2.p.func + and obs_or_fq1.p.args == obs_or_fq2.p.args + and keywords_equal + ) + + +def qconfig_equals(q1: QConfigAny, q2: QConfigAny): + """ + Returns `True` if `q1` equals `q2`, and `False` otherwise. + """ + if q1 is None or q2 is None: + return q1 == q2 + else: + assert q1 is not None and q2 is not None + try: + # Qconfig weight and activation can be either a partial wrapper, + # or an observer class. Special handling is required (above) for + # comparing partial wrappers. + activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation) + weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight) + return activation_same and weight_same + except AttributeError: + return q1 == q2 + + +def _activation_is_memoryless(qconfig: QConfig): + """ + Return whether the observer for activations defined in the given QConfig is memoryless. + This means a MovingAverage observer with averaging constant equal to 1. + """ + + def _is_memoryless(observer): + return ( + hasattr(observer, "averaging_constant") and observer.averaging_constant == 1 + ) + + act = qconfig.activation() + if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"): + return _is_memoryless(act.activation_post_process) + else: + return _is_memoryless(act) + + +def _is_reuse_input_qconfig(qconfig: Optional[QConfig]): + return ( + qconfig is not None + and isinstance(qconfig.activation(), ReuseInputObserver) + and isinstance(qconfig.weight(), NoopObserver) + ) diff --git a/torchao/quantization/pt2e_flow/quantize_pt2e.py b/torchao/quantization/pt2e_flow/quantize_pt2e.py new file mode 100644 index 0000000000..e856a40453 --- /dev/null +++ b/torchao/quantization/pt2e_flow/quantize_pt2e.py @@ -0,0 +1,266 @@ +import torch +from torch._export.passes.constant_folding import constant_fold +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_manager import PassManager + +from torchao.quantization.pt2e_flow.pt2e.convert import ( + _convert_to_reference_decomposed_fx, +) +from torchao.quantization.pt2e_flow.pt2e.duplicate_dq_pass import DuplicateDQPass +from torchao.quantization.pt2e_flow.pt2e.port_metadata_pass import PortNodeMetaForQDQ +from torchao.quantization.pt2e_flow.pt2e.prepare import prepare +from torchao.quantization.pt2e_flow.pt2e.qat_utils import ( + _fold_conv_bn_qat, + _fuse_conv_bn_qat, +) +from torchao.quantization.pt2e_flow.pt2e.representation import ( + reference_representation_rewrite, +) +from torchao.quantization.pt2e_flow.pt2e.utils import ( + _disallow_eval_train, + _fuse_conv_bn_, + _get_node_name_to_scope, +) +from torchao.quantization.pt2e_flow.quantizer import ( # noqa: F401 + DerivedQuantizationSpec, + FixedQParamsQuantizationSpec, + QuantizationAnnotation, + QuantizationSpec, + QuantizationSpecBase, + Quantizer, + SharedQuantizationSpec, +) + +__all__ = [ + "prepare_pt2e", + "prepare_qat_pt2e", + "convert_pt2e", +] + + +def prepare_pt2e( + model: GraphModule, + quantizer: Quantizer, +) -> GraphModule: + """Prepare a model for post training quantization + + Args: + * `model` (torch.fx.GraphModule): a model captured by `torch.export.export_for_training` API. 
+ * `quantizer`: A backend specific quantizer that conveys how user want the + model to be quantized. Tutorial for how to write a quantizer can be found here: + https://pytorch.org/tutorials/prototype/pt2e_quantizer.html + + Return: + A GraphModule with observer (based on quantizer annotation), ready for calibration + + Example:: + + import torch + from torchao.quantization.pt2e_flow.quantize_pt2e import prepare_pt2e + from torchao.quantization.pt2e_flow.quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # initialize a floating point model + float_model = M().eval() + + # define calibration function + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result shoud mostly stay the same + m = torch.export.export_for_training(m, *example_inputs).module() + # we get a model with aten ops + + # Step 2. quantization + # backend developer will write their own Quantizer and expose methods to allow + # users to express how they + # want the model to be quantized + quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) + m = prepare_pt2e(m, quantizer) + + # run calibration + # calibrate(m, sample_inference_data) + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e") + original_graph_meta = model.meta + node_name_to_scope = _get_node_name_to_scope(model) + # TODO: check qconfig_mapping to make sure conv and bn are both configured + # to be quantized before fusion + # TODO: (maybe) rewrite this with subgraph_rewriter + _fuse_conv_bn_(model) + model = quantizer.transform_for_annotation(model) + quantizer.annotate(model) + quantizer.validate(model) + model = prepare( + model, + node_name_to_scope, + is_qat=False, + obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback, + ) + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model + + +def prepare_qat_pt2e( + model: GraphModule, + quantizer: Quantizer, +) -> GraphModule: + """Prepare a model for quantization aware training + + Args: + * `model` (torch.fx.GraphModule): see :func:`~torchao.quantization.pt2e_flow.quantize_pt2e.prepare_pt2e` + * `quantizer`: see :func:`~torchao.quantization.pt2e_flow.quantize_pt2e.prepare_pt2e` + + Return: + A GraphModule with fake quant modules (based on quantizer annotation), ready for + quantization aware training + + Example:: + import torch + from torchao.quantization.pt2e_flow.quantize_pt2e import prepare_qat_pt2e + from torchao.quantization.pt2e_flow.quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # initialize a floating point model + float_model = M().eval() + + # define the training loop for quantization aware training + def train_loop(model, train_data): + model.train() + for image, target in data_loader: + ... + + # Step 1. 
program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result shoud mostly stay the same + m = torch.export.export_for_training(m, *example_inputs).module() + # we get a model with aten ops + + # Step 2. quantization + # backend developer will write their own Quantizer and expose methods to allow + # users to express how they + # want the model to be quantized + quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) + m = prepare_qat_pt2e(m, quantizer) + + # run quantization aware training + train_loop(prepared_model, train_loop) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e") + original_graph_meta = model.meta + node_name_to_scope = _get_node_name_to_scope(model) + model = quantizer.transform_for_annotation(model) + quantizer.annotate(model) + quantizer.validate(model) + # Perform fusion after annotate to avoid quantizing ops in the new + # subgraph that don't need to be quantized + # TODO: only fuse if conv and bn are both configured to be quantized + _fuse_conv_bn_qat(model) + model = prepare( + model, + node_name_to_scope, + is_qat=True, + obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback, + ) + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model + + +_QUANT_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.torchao_quant.quantize_affine, +] + + +def _quant_node_constraint(n: Node) -> bool: + """If there is any pure ops between get_attr and quantize op they will be const propagated + e.g. get_attr(weight) -> transpose -> quantize -> dequantize* + (Note: dequantize op is not going to be constant propagated) + + This filter is added because we don't want to constant fold the things that are not + related to quantization + """ + return n.op == "call_function" and n.target in _QUANT_OPS + + +def convert_pt2e( + model: GraphModule, + use_reference_representation: bool = False, + fold_quantize: bool = True, +) -> GraphModule: + """Convert a calibrated/trained model to a quantized model + + Args: + * `model` (torch.fx.GraphModule): calibrated/trained model + * `use_reference_representation` (bool): boolean flag to indicate whether to produce referece representation or not + * `fold_quantize` (bool): boolean flag for whether fold the quantize op or not + + Returns: + quantized model, either in q/dq representation or reference representation + + Example:: + + # prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training + # `convert_pt2e` produces a quantized model that represents quantized computation with + # quantize dequantize ops and fp32 ops by default. 
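+        # (illustrative) to instead emit the integer reference representation,
+        # pass use_reference_representation=True:
+        #   ref_quantized_model = convert_pt2e(prepared_model, use_reference_representation=True)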
+ # Please refer to + # https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model + # for detailed explanation of output quantized model + quantized_model = convert_pt2e(prepared_model) + + """ # flake8: noqa + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e") + if not isinstance(use_reference_representation, bool): + raise ValueError( + "Unexpected argument type for `use_reference_representation`, " + f"please make sure you intend to pass argument {use_reference_representation} to convert_pt2e" + ) + original_graph_meta = model.meta + model = _convert_to_reference_decomposed_fx(model) + model = _fold_conv_bn_qat(model) + + pm = PassManager([DuplicateDQPass()]) + model = pm(model).graph_module + + pm = PassManager([PortNodeMetaForQDQ()]) + model = pm(model).graph_module + + if fold_quantize: + constant_fold(model, _quant_node_constraint) + + if use_reference_representation: + model = reference_representation_rewrite(model) + + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model diff --git a/torchao/quantization/pt2e_flow/quantizer/__init__.py b/torchao/quantization/pt2e_flow/quantizer/__init__.py new file mode 100644 index 0000000000..e65652573b --- /dev/null +++ b/torchao/quantization/pt2e_flow/quantizer/__init__.py @@ -0,0 +1,21 @@ +from .quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + FixedQParamsQuantizationSpec, + QuantizationAnnotation, + QuantizationSpec, + QuantizationSpecBase, + Quantizer, + SharedQuantizationSpec, +) + +__all__ = [ + "EdgeOrNode", + "Quantizer", + "QuantizationSpecBase", + "QuantizationSpec", + "FixedQParamsQuantizationSpec", + "SharedQuantizationSpec", + "DerivedQuantizationSpec", + "QuantizationAnnotation", +] diff --git a/torchao/quantization/pt2e_flow/quantizer/composable_quantizer.py b/torchao/quantization/pt2e_flow/quantizer/composable_quantizer.py new file mode 100644 index 0000000000..a90241d38e --- /dev/null +++ b/torchao/quantization/pt2e_flow/quantizer/composable_quantizer.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .quantizer import QuantizationAnnotation, Quantizer + +if TYPE_CHECKING: + import torch + from torch.fx import Node + +__all__ = [ + "ComposableQuantizer", +] + + +class ComposableQuantizer(Quantizer): + """ + ComposableQuantizer allows users to combine more than one quantizer into a single quantizer. + This allows users to quantize a model with multiple quantizers. E.g., embedding quantization + maybe supported by one quantizer while linear layers and other ops might be supported by another + quantizer. + + ComposableQuantizer is initialized with a list of `Quantizer` instances. + The order of the composition matters since that is the order in which the quantizers will be + applies. 
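+
+    Note that a quantizer later in the list must not change or remove annotations
+    already made by an earlier quantizer; `ComposableQuantizer` validates this
+    while annotating and raises a `RuntimeError` if it happens.
+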
+ Example: + ``` + embedding_quantizer = EmbeddingQuantizer() + linear_quantizer = MyLinearQuantizer() + xnnpack_quantizer = XNNPackQuantizer() # to handle ops not quantized by previous two quantizers + composed_quantizer = ComposableQuantizer([embedding_quantizer, linear_quantizer, xnnpack_quantizer]) + prepared_m = prepare_pt2e(model, composed_quantizer) + ``` + """ + + def __init__(self, quantizers: list[Quantizer]): + super().__init__() + self.quantizers = quantizers + self._graph_annotations: dict[Node, QuantizationAnnotation] = {} + + def _record_and_validate_annotations( + self, gm: torch.fx.GraphModule, quantizer: Quantizer + ) -> None: + for n in gm.graph.nodes: + if "quantization_annotation" in n.meta: + # check if the annotation has been changed by + # comparing QuantizationAnnotation object id + if n in self._graph_annotations and ( + id(self._graph_annotations[n]) + != id(n.meta["quantization_annotation"]) + ): + raise RuntimeError( + f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}" + ) + else: + self._graph_annotations[n] = n.meta["quantization_annotation"] + else: + if n in self._graph_annotations: + raise RuntimeError( + f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}" + ) + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + for quantizer in self.quantizers: + quantizer.annotate(model) + self._record_and_validate_annotations(model, quantizer) + return model + + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + for quantizer in self.quantizers: + model = quantizer.transform_for_annotation(model) + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass diff --git a/torchao/quantization/pt2e_flow/quantizer/embedding_quantizer.py b/torchao/quantization/pt2e_flow/quantizer/embedding_quantizer.py new file mode 100644 index 0000000000..5109a4d7e0 --- /dev/null +++ b/torchao/quantization/pt2e_flow/quantizer/embedding_quantizer.py @@ -0,0 +1,97 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import copy + +import torch +import torch.nn.functional as F + +from torchao.quantization.pt2e_flow.observer import PerChannelMinMaxObserver +from torchao.quantization.pt2e_flow.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, +) +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer_utils import ( + OperatorConfig, + OperatorPatternType, + QuantizationConfig, +) + +__all__ = [ + "get_embedding_operators_config", + "EmbeddingQuantizer", +] + + +def get_embedding_operators_config() -> OperatorConfig: + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12), + ) + quantization_config = QuantizationConfig(None, None, weight_quantization_spec, None) + ops: list[OperatorPatternType] = [[torch.nn.Embedding]] + ops.append([F.embedding]) + supported_config_and_operators = OperatorConfig( + config=quantization_config, operators=ops + ) + return copy.deepcopy(supported_config_and_operators) + + +class EmbeddingQuantizer(Quantizer): + def __init__(self) -> None: + super().__init__() + + @classmethod + def get_supported_quantization_configs(cls) -> list[QuantizationConfig]: + op_configs: set[QuantizationConfig] = { + spec for spec, _ in cls.get_supported_operators() + } + return 
list(op_configs) + + @classmethod + def get_supported_operator_for_quantization_config( + cls, quantization_config: QuantizationConfig + ) -> list[OperatorPatternType]: + for config, ops in cls.get_supported_operators(): + # note: this assumes each entry in cls.supported_spec_and_operators + # corresponds to one spec, e.g. we don't have + # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] + # where the first and second entry have the same spec but did not + # merge the op list + if config == quantization_config: + return ops + return [] + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + self._annotate_embedding_ops(model.graph) + return model + + def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None: + embedding_config: OperatorConfig = get_embedding_operators_config() + for node in graph.nodes: + # Keep node parsing based annotations instead of module partitioners + # just as an example of alternate ways of annotating + if ( + node.op == "call_function" + and node.target == torch.ops.aten.embedding.default + ): + if embedding_config.config.weight is None: + raise ValueError( + "Embedding config must have a valid weight quantization spec." + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + node.args[0]: embedding_config.config.weight, + } + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + @classmethod + def get_supported_operators(cls) -> list[OperatorConfig]: + return [get_embedding_operators_config()] diff --git a/torchao/quantization/pt2e_flow/quantizer/quantizer.py b/torchao/quantization/pt2e_flow/quantizer/quantizer.py new file mode 100644 index 0000000000..d019235ca9 --- /dev/null +++ b/torchao/quantization/pt2e_flow/quantizer/quantizer.py @@ -0,0 +1,180 @@ +# mypy: allow-untyped-defs +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Callable, Optional, Union + +import torch +from torch import Tensor +from torch.ao.quantization import ObserverOrFakeQuantize +from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor +from torch.fx import Node + +__all__ = [ + "Quantizer", + "QuantizationSpecBase", + "QuantizationSpec", + "FixedQParamsQuantizationSpec", + "EdgeOrNode", + "SharedQuantizationSpec", + "DerivedQuantizationSpec", + "QuantizationAnnotation", +] + + +class QuantizationSpecBase(ABC): # noqa: B024 + """Base class for different types of quantization specs that allows users to + specify how to quantize a Tensor (input/output of a Node) in the model + """ + + +@dataclass(eq=True, frozen=True) +class QuantizationSpec(QuantizationSpecBase): + """Quantization spec for common operators that allows user to specify how to + quantize a Tensor, this includes dtype, quant_min, quant_max etc. + """ + + dtype: torch.dtype + # observer or fake_quantize constructor such as + # MinMaxObserver, PerChannelHistogramObserver etc. + # or we can attach some custom args to them + # e.g. 
MinMaxObserver.with_args(eps=eps) + observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor + quant_min: Optional[int] = None + quant_max: Optional[int] = None + qscheme: Optional[torch.qscheme] = None + ch_axis: Optional[int] = None + is_dynamic: bool = False + + def __post_init__(self): + # TODO: add init for quant_min/quant_max + # quant_min must be less than quant_max + if ( + self.quant_min is not None + and self.quant_max is not None + and self.quant_min > self.quant_max + ): + raise ValueError( + f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}." + ) + + # ch_axis must be less than the number of channels + # but no way to check here. Just check that it is not < 0. + if self.ch_axis is not None and self.ch_axis < 0: + raise ValueError("Ch_axis is < 0.") + + +@dataclass(eq=True, frozen=True) +class FixedQParamsQuantizationSpec(QuantizationSpecBase): + dtype: torch.dtype + scale: float + zero_point: int + quant_min: Optional[int] = None + quant_max: Optional[int] = None + qscheme: Optional[torch.qscheme] = None + is_dynamic: bool = False + + +""" +The way we refer to other points of quantization in the graph will be either +an input edge or an output value +input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node] +output value is an fx Node +""" +EdgeOrNode = Union[tuple[Node, Node], Node] +EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer" + + +@dataclass(eq=True, frozen=True) +class SharedQuantizationSpec(QuantizationSpecBase): + """ + Quantization spec for the Tensors whose quantization parameters are shared with other Tensors + """ + + # the edge or node to share observer or fake quant instances with + edge_or_node: EdgeOrNode + + +@dataclass(eq=True, frozen=True) +class DerivedQuantizationSpec(QuantizationSpecBase): + """Quantization spec for the Tensors whose quantization parameters are derived from other Tensors""" + + derived_from: list[EdgeOrNode] + derive_qparams_fn: Callable[[list[ObserverOrFakeQuantize]], tuple[Tensor, Tensor]] + dtype: torch.dtype + quant_min: Optional[int] = None + quant_max: Optional[int] = None + qscheme: Optional[torch.qscheme] = None + ch_axis: Optional[int] = None + is_dynamic: bool = False + + +@dataclass +class QuantizationAnnotation: + """How are input arguemnt or output should be quantized, + expressed as QuantizationSpec, this corresponds to how a Tensor in the + operator Graph is observed (PTQ) or fake quantized (QAT) + """ + + # a map from torch.fx.Node to a type of QuantizationSpecBase + input_qspec_map: dict[Node, Optional[QuantizationSpecBase]] = field( + default_factory=dict + ) + + # How the output of this node is quantized, expressed as QuantizationSpec + # TODO: change the value to QuantizationSpec in a separate PR + output_qspec: Optional[QuantizationSpecBase] = None + + # For a Node: node1 and edge: (node1, node2), since they are observing the same + # Tensor, we may want to implicitly share observers, this flag allows people to + # turn off this behavior for the output of the node + allow_implicit_sharing: bool = True + + # whether the node is annotated or not + _annotated: bool = False + + +class Quantizer(ABC): + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + """Allows for user defined transforms to run before annotating the graph. + This allows quantizer to allow quantizing part of the model that are otherwise not quantizable. 
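# A minimal usage sketch of the spec classes defined above (illustrative only,
# not taken from the migrated sources): an int8 per-tensor QuantizationSpec
# built from MinMaxObserver, used to annotate a binary aten.add node whose
# output shares quantization parameters with its first input edge via
# SharedQuantizationSpec. The helper name `annotate_add` and the concrete
# qparam choices are assumptions for the example.
import torch

from torchao.quantization.pt2e_flow.observer import MinMaxObserver
from torchao.quantization.pt2e_flow.quantizer.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    SharedQuantizationSpec,
)

act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver.with_args(eps=2**-12),
)


def annotate_add(add_node: torch.fx.Node) -> None:
    # Observe both inputs with `act_qspec`; the output re-uses the observer
    # created for the (first_input, add_node) input edge.
    first_input, second_input = add_node.args[0], add_node.args[1]
    add_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={first_input: act_qspec, second_input: act_qspec},
        output_qspec=SharedQuantizationSpec((first_input, add_node)),
        _annotated=True,
    )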
+ For example quantizer can + a) decompose a compound operator like scaled dot product attention, + into bmm and softmax if quantizer knows how to quantize bmm/softmax but not sdpa + or b) transform scalars to tensor to allow quantizing scalares. + + Note: this is an optional method + """ + return model + + # annotate nodes in the graph with observer or fake quant constructors + # to convey the desired way of quantization + @abstractmethod + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + pass + + # validate the annotated graph is supported by the backend + @abstractmethod + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + def prepare_obs_or_fq_callback( + self, + model: torch.fx.GraphModule, + edge_or_node_to_obs_or_fq: dict[EdgeOrNode, ObserverOrFakeQuantize], + ) -> None: + """A callback that will be called after the observers or fake quants are created + for each sharing group, but before they are inserted into the graph. The + callback can be used to make final quantization adjustments, such as enforcing + specific scale and zero point on model input or output. + + Args: + * `model`: the graph module being prepared. + * `edge_or_node_to_obs_or_fq`: a dictionary mapping each annotated edge and + node to the corresponding observer or fake quant object. Note that multiple + edges and/or nodes can map to the same observer / fake quant instance if + they were annotated with SharedQuantizationSpec. This dictionary can be + modified by the callback. + """ + return diff --git a/torchao/quantization/pt2e_flow/quantizer/utils.py b/torchao/quantization/pt2e_flow/quantizer/utils.py new file mode 100644 index 0000000000..248acde03f --- /dev/null +++ b/torchao/quantization/pt2e_flow/quantizer/utils.py @@ -0,0 +1,83 @@ +# mypy: allow-untyped-defs + +from torch.fx import Node + +from torchao.quantization.pt2e_flow.pt2e.utils import _is_sym_size_node +from torchao.quantization.pt2e_flow.quantizer.quantizer import QuantizationAnnotation + + +def _annotate_input_qspec_map(node: Node, input_node: Node, qspec): + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + quantization_annotation.input_qspec_map[input_node] = qspec + node.meta["quantization_annotation"] = quantization_annotation + + +def _annotate_output_qspec(node: Node, qspec): + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + quantization_annotation.output_qspec = qspec + node.meta["quantization_annotation"] = quantization_annotation + + +def _node_only_used_for_sym_size(node: Node, partition_nodes: list[Node]): + """ + This utility is used to handle cases when dynami_shape=True tracing leads + to symint nodes in the pattern of linear module. In those cases, we need to + distinguish between the nodes that are in input for just extracting value of + some dimentions (and symint nodes) vs. the one that is activation. + For example: + graph(x, y, weight): + size_0 = torch.ops.aten.sym_size([x], [0]) + size_1 = torch.ops.aten.sym_size([y], [1]) + view_size = size_0 * size_1 + size_3 = torch.ops.aten.sym_size([x], [2]) + vie_out = torch.ops.aten.view(x, [view_size, size_3]) + return mm(view_out, weight) + In the example above y node is not actual input. 
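# A minimal sketch of a concrete Quantizer built on the abstract class above
# (illustrative only): it annotates the activation and weight inputs of every
# aten.linear node with the specs passed in and leaves validate() as a no-op.
# `SimpleLinearQuantizer` is a hypothetical name, not part of this migration.
import torch

from torchao.quantization.pt2e_flow.quantizer.quantizer import (
    QuantizationAnnotation,
    QuantizationSpecBase,
    Quantizer,
)


class SimpleLinearQuantizer(Quantizer):
    def __init__(
        self, act_qspec: QuantizationSpecBase, weight_qspec: QuantizationSpecBase
    ) -> None:
        super().__init__()
        self.act_qspec = act_qspec
        self.weight_qspec = weight_qspec

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for node in model.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.linear.default
            ):
                node.meta["quantization_annotation"] = QuantizationAnnotation(
                    input_qspec_map={
                        node.args[0]: self.act_qspec,  # activation
                        node.args[1]: self.weight_qspec,  # weight
                    },
                    output_qspec=self.act_qspec,
                    _annotated=True,
                )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass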
It exist only to extract size_1 + """ + if _is_sym_size_node(node): + return True + + return all( + ((user not in partition_nodes) or _is_sym_size_node(user)) + for user in node.users + ) + + +def _get_module_name_filter(module_name: str): + """Get the module_name_filter function for a given module name, the filter accepts + a node and checks if the node comes from a module that has certain module name + + For example: + node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 + + + >> module_name_filter = _get_module_name_filter("blocks.sub") + >> print(module_name_filter(node)) + True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + """ + + def module_name_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) + + def _normalize_path(n): + prefix = 0 + # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph. + if n.startswith("L['self']."): + prefix = len("L['self'].") + return n[prefix:] + + names = [_normalize_path(n) for n, _ in nn_module_stack.values()] + return module_name in names + + return module_name_filter diff --git a/torchao/quantization/pt2e_flow/quantizer/x86_inductor_quantizer.py b/torchao/quantization/pt2e_flow/quantizer/x86_inductor_quantizer.py new file mode 100644 index 0000000000..18d97efe58 --- /dev/null +++ b/torchao/quantization/pt2e_flow/quantizer/x86_inductor_quantizer.py @@ -0,0 +1,1572 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import operator +import warnings +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch.fx import Node +from torch.fx.passes.utils.source_matcher_utils import ( + SourcePartition, + get_source_partitions, +) +from typing_extensions import TypeAlias + +from torchao.quantization.pt2e_flow.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torchao.quantization.pt2e_flow.observer import ( + HistogramObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, + PlaceholderObserver, +) +from torchao.quantization.pt2e_flow.pt2e.graph_utils import find_sequential_partitions +from torchao.quantization.pt2e_flow.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, + SharedQuantizationSpec, +) +from torchao.quantization.pt2e_flow.quantizer.utils import _get_module_name_filter +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer_utils import ( + QuantizationConfig, + get_bias_qspec, + get_input_act_qspec, + get_output_act_qspec, + get_weight_qspec, +) + +FilterFn: TypeAlias = Callable[[list[Node]], bool] + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + +__all__ = [ + "X86InductorQuantizer", + "get_default_x86_inductor_quantization_config", + "get_x86_inductor_linear_dynamic_fp16_config", +] + + +@dataclass +class _X86InductorQuantizationAnnotation(QuantizationAnnotation): + # _is_output_of_quantized_pattern: + # * Node as output node of a fusion pattern. + # * The fusion pattern supports int8 data type. + # * The fusion pattern has inputs annotated to insert observer. 
+ # * The quantization_config is not `None`. + _is_output_of_quantized_pattern: bool = False + + +# Operators that: +# 1. Operators are optimized to run with int8 when int8 input provided. +# 2. Operators do not support int8 input and produce fp32 output. +int8_in_int8_out_ops: set = { + torch.ops.aten.max_pool2d.default, + torch.ops.aten.cat.default, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.flatten.using_ints, +} + +# Operators that support the int8 data type for quantization config propagation. +# A superset of int8_in_int8_out_ops incorporating additional operators. +propagation_quantizable_ops = int8_in_int8_out_ops + +# Operators support the int8 data type +# and recipe is configured by default in X86InductorQuantizer. +default_quantizable_ops = propagation_quantizable_ops | { + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, +} + +# A superset of default_quantizable_ops includes operators support the int8 data type +# but not enabled by default recipe of X86InductorQuantizer. +quantizable_ops = default_quantizable_ops | { + torch.ops.aten.matmul.default, +} + +QUANT_ANNOTATION_KEY = "quantization_annotation" + + +def _skip_annotate(nodes: list[Node], filter_fn: Optional[FilterFn] = None) -> bool: + """Determine whether to skip annotation for a list of nodes.""" + + # 1) Skip annotate if any node is already annotated + if _is_any_annotated(nodes): + return True + + # 2) Proceed annotate if a) a filter function is provided + # and b) the given nodes list passes the filter function check. + if filter_fn and filter_fn(nodes): + return False + + return True + + +def _create_module_name_filter(module_name: str) -> FilterFn: + """Create a filter function for a given module name. + + The filter function takes a list of nodes (as determined by the annotate function) + and return True if *all* nodes come from the specified module name, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> module_name_filter = _create_module_name_filter_inner("sub") + >> print(module_name_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and from "sub". + """ + + filter_fn = _get_module_name_filter(module_name) + + def check_all_nodes_from_module(nodes: list[Node]) -> bool: + all_nodes_from_module_name: bool = all(filter_fn(n) for n in nodes) + return all_nodes_from_module_name + + return check_all_nodes_from_module + + +def _create_operator_type_filter( + operator_type: Callable, +) -> FilterFn: + """Create a filter function for a given operator type. + + The filter function takes a list of nodes and returns True if it contains + exactly one node with the specified operator type, False otherwise. + + For example: + linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1` + relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1` + + >> operator_type_filter = _create_operator_type_filter(torch.ops.aten.linear.default) + >> print(operator_type_filter([relu, linear_1])) + # True # These two nodes are determined by `_annotate_linear_unary` function and the second node is `linear`. 
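# The sets above drive the X86InductorQuantizer recipe: ops in
# `default_quantizable_ops` are picked up by the global config, while ops that
# are only in `quantizable_ops` (currently matmul) must be opted in explicitly.
# A short usage sketch of that opt-in, relying on set_global and
# set_function_type_qconfig defined later in this file:
import torch

from torchao.quantization.pt2e_flow.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

config = get_default_x86_inductor_quantization_config()
quantizer = X86InductorQuantizer()
quantizer.set_global(config)  # conv2d / linear and the propagation ops
quantizer.set_function_type_qconfig(torch.matmul, config)  # opt matmul in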
+ """ + + def operator_type_filter(nodes: list[Node]): + num_nodes_with_operator_type = sum( + node.target == operator_type for node in nodes + ) + if num_nodes_with_operator_type > 1: + raise NotImplementedError( + f"Several nodes within a single pattern are {operator_type}." + ) + return num_nodes_with_operator_type == 1 + + return operator_type_filter + + +def _global_config_filter(nodes: list[Node]) -> bool: + """Filter function for global configuration. + + This filter function takes a list of nodes and returns True if there is exactly one node + in the list that is a default quantizable operation, False otherwise. + """ + num_nodes_in_default_quantizable_ops = sum( + node.target in default_quantizable_ops for node in nodes + ) + if num_nodes_in_default_quantizable_ops > 1: + raise NotImplementedError( + "Several nodes within a single pattern are default quantizable operations." + ) + return num_nodes_in_default_quantizable_ops == 1 + + +def _map_module_function_to_aten_operator_type(): + module_function_to_aten_operator: dict[Callable, torch._ops.OpOverloadPacket] = {} + map_list = ( + ([torch.nn.Conv2d, F.conv2d], torch.ops.aten.conv2d.default), + ([torch.nn.Linear, F.linear], torch.ops.aten.linear.default), + ([torch.nn.MaxPool2d, F.max_pool2d], torch.ops.aten.max_pool2d.default), + ( + [ + torch.cat, + ], + torch.ops.aten.cat.default, + ), + ([torch.nn.AvgPool2d, F.avg_pool2d], torch.ops.aten.avg_pool2d.default), + ( + [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], + torch.ops.aten.adaptive_avg_pool2d.default, + ), + ( + [ + torch.flatten, + ], + torch.ops.aten.flatten.using_ints, + ), + ( + [ + torch.matmul, + ], + torch.ops.aten.matmul.default, + ), + ) + for map_item in map_list: + module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[arg-type, call-overload] + return module_function_to_aten_operator + + +def _mark_nodes_as_annotated(nodes: list[Node]): + for node in nodes: + if node is not None: + if QUANT_ANNOTATION_KEY not in node.meta: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation() + node.meta[QUANT_ANNOTATION_KEY]._annotated = True + + +def _is_node_annotated(_node): + """ + return True if the node is annotated, otherwise return False + """ + return ( + QUANT_ANNOTATION_KEY in _node.meta + and _node.meta[QUANT_ANNOTATION_KEY]._annotated + ) + + +def _is_any_annotated(nodes: list[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False. + """ + return any(_is_node_annotated(node) for node in nodes) + + +def _is_all_annotated(nodes: list[Node]): + """ + Given a list of nodes (that represents an operator pattern), + return True if all of the node is annotated, otherwise return False. + """ + return all(_is_node_annotated(node) for node in nodes) + + +def _is_quantized_op_pt2e(node: torch.fx.Node): + """ + Used for pt2e flow to check if the node is a quantized node: + Case1: the node has been annotated as output node of a fusion pattern. + Case2: the node has been annotated as single quantized node. 
+ """ + if not _is_any_annotated([node]): + # The node has not been annotated, directly return False + return False + quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None) + assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation) + return quantization_annotation._is_output_of_quantized_pattern + + +@functools.lru_cache +def get_default_x86_inductor_quantization_config( + is_qat: bool = False, + is_dynamic: bool = False, + reduce_range: bool = False, +): + """ + reduce_range is False by default. Set it to True on earlier CPUs without VNNI to avoid accuracy issue. + """ + extra_args: dict[str, Any] = {"eps": 2**-12} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + # Copy from x86 default qconfig from torch/ao/quantization/qconfig.py + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=127 if reduce_range else 255, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + FusedMovingAvgObsFakeQuantize if is_qat else PerChannelMinMaxObserver + ) + + if is_qat: + # Only support per channel quant for now + extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + bias_quantization_spec = None # will use placeholder observer by default + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + return quantization_config + + +@functools.lru_cache +def get_x86_inductor_linear_dynamic_fp16_config(): + """ + For linear_dynamic_fp16. The name may be confusing. + The op's behavior is fp32_input * (fp16_weight -> to_fp32) -> fp32_output. 
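# A minimal end-to-end sketch of the static PTQ flow this default config feeds
# into, mirroring the export -> prepare_pt2e -> calibrate -> convert_pt2e flow
# referenced in convert_pt2e's docstring. The toy model and calibration inputs
# are assumptions for the example.
import torch
from torch.export import export_for_training

from torchao.quantization.pt2e_flow.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e_flow.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 16, 16),)

exported = export_for_training(model, example_inputs).module()

quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config())

prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass(es)
quantized = convert_pt2e(prepared)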
+ """ + weight_quantization_spec = QuantizationSpec( + dtype=torch.float16, + observer_or_fake_quant_ctr=PlaceholderObserver, + ) + quantization_config = QuantizationConfig( + None, # input_quantization_spec + None, # output_quantization_spec + weight_quantization_spec, + None, # bias_quantization_spec + ) + return quantization_config + + +def _annotate_nodes_not_quantize(nodes: Union[Node, list[Node]]) -> None: + """Annotate nodes to exclude them from quantization (their `quantization_config` is `None`).""" + if not isinstance(nodes, list): + nodes = [nodes] + for node in nodes: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True + ) + + +def _config_checker(method: Callable) -> Callable: + @functools.wraps(method) + def wrapper( + quantizer: "X86InductorQuantizer", + name: Any, + quantization_config: Optional["QuantizationConfig"], + ) -> "X86InductorQuantizer": + if quantizer._need_skip_config(quantization_config): + warnings.warn( + f"Skip the quantization config for {name}.", + ) + return quantizer + return method(quantizer, name, quantization_config) + + return wrapper + + +@dataclass +class _CurrentQuantizationMode: + r"""Configuration defining the current quantization mode for the quantizer. + + All possible current quantization modes are listed below: + ---------------------------------------------------------------------------------------------------------- + | dynamic_state + qat_state |--------------------------------------------------------------------------------------------- + | None | True | False + ---------------------------------------------------------------------------------------------------------- + None | quantizer does not receive a non-None `quantization_config` | \ | \ + False | quantizer will not do QAT | dynamic | static + True | quantizer will do QAT | QAT + dynamic | QAT + static + """ + + qat_state: Optional[bool] + dynamic_state: Optional[bool] + + +class X86InductorQuantizer(Quantizer): + module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type() + + def __init__(self) -> None: + super().__init__() + self.global_config: Optional[QuantizationConfig] = None + self.operator_type_qconfig: dict[ + torch._ops.OpOverloadPacket, Optional[QuantizationConfig] + ] = {} + self.module_name_qconfig: dict[str, Optional[QuantizationConfig]] = {} + + def _get_current_quantization_mode(self) -> _CurrentQuantizationMode: + """Retrieves the current quantization mode based on all configurations.""" + qat_state = None + dynamic_state = None + + # As we use `_need_skip_config` to skip all invalid configurations, + # we can safely assume that the all existing non-None configurations + # have the same quantization mode. + for qconfig in ( + list(self.module_name_qconfig.values()) + + list(self.operator_type_qconfig.values()) + + [self.global_config] + ): + if qconfig is not None: + # Query the `is_qat` state + if qat_state is None: + qat_state = qconfig.is_qat + else: + assert qat_state == qconfig.is_qat, ( + f"All non-None quantization configs should have the same `is_qat`," + f"but got {qat_state} and {qconfig.is_qat}." 
+ ) + # Query the `is_dynamic` state + input_activation_spec = qconfig.input_activation + if input_activation_spec is not None: + if dynamic_state is None: + dynamic_state = input_activation_spec.is_dynamic + else: + assert dynamic_state == input_activation_spec.is_dynamic, ( + f"All non-None `input_activation_spec` should have the same `is_dynamic`," + f"but got {dynamic_state} and {input_activation_spec.is_dynamic}." + ) + return _CurrentQuantizationMode( + qat_state=qat_state, dynamic_state=dynamic_state + ) + + def _need_skip_config( + self, quantization_config: Optional[QuantizationConfig] + ) -> bool: + """Check if the provided quantization config is valid for X86InductorQuantizer. + + Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported. + To avoid such a mix, we compare the incoming configuration with current configuration status. + Refer the `_CurrentQuantizationMode` definition for all possible modes. + """ + if quantization_config is None: + return False + + need_skip = False + current_mode = self._get_current_quantization_mode() + if ( + current_mode.qat_state is not None + and current_mode.qat_state != quantization_config.is_qat + ): + warnings.warn("Mixed QAT and Non-QAT quantization config is not supported.") + need_skip = True + if current_mode.dynamic_state is not None: + input_activation_spec = quantization_config.input_activation + if ( + input_activation_spec is not None + and current_mode.dynamic_state != input_activation_spec.is_dynamic + ): + warnings.warn( + "Mixed dynamic and static quantization config is not supported." + ) + need_skip = True + return need_skip + + def set_global(self, quantization_config: QuantizationConfig): + if self._need_skip_config(quantization_config): + warnings.warn("Skip the global quantization config.") + return self + self.global_config = quantization_config + return self + + def get_global_quantization_config(self): + if not isinstance(self.global_config, QuantizationConfig): + warnings.warn( + "The global_config for X86InductorQuantizer is currently invalid. \ + Please ensure that you use set_global to establish the global quantization configuration." + ) + return self.global_config + + @_config_checker + def set_function_type_qconfig( + self, + function_type: Callable, + quantization_config: Optional[QuantizationConfig], + ) -> "X86InductorQuantizer": + if function_type in X86InductorQuantizer.module_function_to_aten_operator_type: + self._set_aten_operator_qconfig( + X86InductorQuantizer.module_function_to_aten_operator_type[ + function_type + ], + quantization_config, + ) + else: + warnings.warn( + f"function: Unable to customize quantization config for {function_type} by X86InductorQuantizer." + ) + return self + + @_config_checker + def set_module_type_qconfig( + self, + module_type: torch.nn.Module, + quantization_config: Optional[QuantizationConfig], + ) -> "X86InductorQuantizer": + if module_type in X86InductorQuantizer.module_function_to_aten_operator_type: + self._set_aten_operator_qconfig( + X86InductorQuantizer.module_function_to_aten_operator_type[module_type], + quantization_config, + ) + else: + warnings.warn( + f"Module: Unable to customize quantization config for {module_type} by X86InductorQuantizer." 
+ ) + return self + + @_config_checker + def set_module_name_qconfig( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + + The supported operators include `quantizable_ops` and `propagation_quantizable_ops`. + """ + self.module_name_qconfig[module_name] = quantization_config + return self + + def _set_aten_operator_qconfig( + self, + operator_type: torch._ops.OpOverloadPacket, + quantization_config: Optional[QuantizationConfig], + ) -> "X86InductorQuantizer": + if operator_type in quantizable_ops: + self.operator_type_qconfig[operator_type] = quantization_config + else: + warnings.warn( + f"operator: Unable to quantize {operator} by X86InductorQuantizer." + ) + return self + + def _annotate_conv_node_helper( + self, + conv_node: torch.fx.Node, + annotate_output: bool, + quantization_config: Optional[QuantizationConfig], + ) -> None: + """Helper function to annotate the conv node""" + if quantization_config is None: + _annotate_nodes_not_quantize(conv_node) + return + input_qspec_map = {} + input_node = conv_node.args[0] + assert isinstance(input_node, Node) + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + weight_node = conv_node.args[1] + assert isinstance(weight_node, Node) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + bias_node = None if len(conv_node.args) == 2 else conv_node.args[2] + if isinstance(bias_node, Node): + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + if annotate_output: + conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + + def _annotate_linear_node_helper( + self, + linear_node: torch.fx.Node, + annotate_output: bool, + quantization_config: Optional[QuantizationConfig], + ) -> None: + """Helper function to annotate the linear node""" + if quantization_config is None: + _annotate_nodes_not_quantize(linear_node) + return + input_qspec_map = {} + assert linear_node.target in (torch.ops.aten.linear.default,) + has_bias = len(linear_node.args) == 3 + input_index = 0 + weight_index = 1 + bias_index = 2 + + input_node = linear_node.args[input_index] + assert isinstance(input_node, Node) + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + + weight_node = linear_node.args[weight_index] + assert isinstance(weight_node, Node) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + + bias_node = linear_node.args[bias_index] if has_bias else None + if isinstance(bias_node, Node): + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + + if annotate_output: + linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, _annotated=True + ) + + def _get_output_nodes_of_partitions( + self, + partition_list: list[SourcePartition], + ) -> list[torch.fx.Node]: + """Helper 
function to get the output node list from partition list""" + output_node_list = [] + for partition in partition_list: + if len(partition.output_nodes) > 1: + raise ValueError("Input partition has more than one output node") + output_node = partition.output_nodes[0] + assert isinstance(output_node, Node) + output_node_list.append(output_node) + if len(output_node_list) != len(partition_list): + raise ValueError( + "length of output_node_list should equal to length of partition_list" + ) + return output_node_list + + def _get_input_idx_for_binary_node( + self, + conv_gemm_node: torch.fx.Node, + binary_node: torch.fx.Node, + ): + """Helper function to check conv_gemm and extra input node index + for binary node fused with conv_gemm. + """ + conv_gemm_node_idx = None + extra_input_node_idx = None + if (binary_node.args[0].op == "call_function") and ( # type: ignore[union-attr] + binary_node.args[0] == conv_gemm_node + ): + conv_gemm_node_idx = 0 + extra_input_node_idx = 1 + elif (binary_node.args[1].op == "call_function") and ( # type: ignore[union-attr] + binary_node.args[1] == conv_gemm_node + ): + conv_gemm_node_idx = 1 + extra_input_node_idx = 0 + extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index] + assert isinstance(extra_input_node, Node) + return conv_gemm_node_idx, extra_input_node_idx + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Annotate the given model with quantization configurations. + + Annotation contracts: + 1. Annotate each node according to the user's qconfig in the following order: + `module_name_qconfig`, `operator_type_qconfig`, and `global_config`. + 2. Avoid re-annotating nodes already annotated in prior stages. For example, + if `linear1` has been annotated by `module_name_qconfig`, it won't be annotated again + during the processing of the 'operator_type_qconfig' or 'global_config'. + 3. For config is `None`, the node will be annotated with `_X86InductorQuantizationAnnotation(_annotated=True)`. + + For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created. + This filter function checks if the node is marked by current stage and not annotated by the previous stage. + """ + for module_name, quantization_config in self.module_name_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_module_name_filter(module_name) + ) + + for operator_type, quantization_config in self.operator_type_qconfig.items(): + self._annotate_with_config( + model, quantization_config, _create_operator_type_filter(operator_type) + ) + + if self.global_config: + self._annotate_with_config( + model, + self.global_config, + _global_config_filter, + ) + + # Once we've annotated the model with quantization configurations, we also need to annotate + # the output of quantizable operations. For example, if we annotated `maxpool2d` to quantize its inputs, + # we will quantize its output accordingly. This enables us to fuse the dq-operator-q into a quantized op. + # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ + # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 + + self._annotate_output_for_int8_in_int8_out_pattern_entry(model) + + return model + + def _annotate_with_config( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: FilterFn, + ) -> None: + """Annotate the model with the given quantization configuration. 
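# A short sketch of the precedence described in annotate() above: module-name
# configs are applied first, then operator-type configs, then the global
# config, and nodes annotated by an earlier stage are never re-annotated.
# Passing None keeps a submodule un-quantized. The names "blocks.sub" and
# "head" are placeholder module names.
from torchao.quantization.pt2e_flow.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

config = get_default_x86_inductor_quantization_config()
quantizer = X86InductorQuantizer()
quantizer.set_module_name_qconfig("blocks.sub", config)  # highest priority
quantizer.set_module_name_qconfig("head", None)  # keep "head" in fp32
quantizer.set_global(config)  # everything else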
+ + High-level description of quantization recipe for X86 Inductor Backend: + Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. + Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model + from start to the end. If a pattern supports computation with int8 data type and inputs connected to + quantized patterns, annotate its inputs as quantized pattern. + """ + + # Step1: Recipe of fusion patterns like conv/linear. + self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn) + self._annotate_matmul(model, quantization_config, filter_fn) + + # Step2: Recipe to propagate annotation for patterns beside conv/linear. + # Go through all the nodes from start to end. + # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/ + # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 + + self._annotate_propagation_quantizable_pattern_entry( + model, quantization_config, filter_fn + ) + + def _annotate_qat_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + # Annotate QAT Specific patterns + self._annotate_qat_conv2d_bn_binary_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_binary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn_unary(model, quantization_config, filter_fn) + self._annotate_qat_conv2d_bn(model, quantization_config, filter_fn) + + def _annotate_qat_conv2d_bn_binary_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] + ) + for fused_partition in fused_partitions: + ( + conv_partition, + bn_partition, + binary_partition, + unary_partition, + ) = fused_partition + + ( + conv_node, + bn_output_node, + binary_node, + unary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, binary_partition, unary_partition] + ) + if len(bn_output_node.users) != 1: + # Conv BN pattern should only has 1 user. 
+ continue + ( + bn_output_node_idx, + extra_input_node_idx, + ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node) + if (bn_output_node_idx is None) or (extra_input_node_idx is None): + continue + if bn_output_node != binary_node.args[bn_output_node_idx]: + raise ValueError(f"{bn_output_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate( + [unary_node, binary_node, bn_output_node, conv_node], filter_fn + ): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) + ) + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize([binary_node, unary_node]) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(binary_partition.nodes)) + nodes_to_mark_annotated.extend(list(unary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn_binary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add] + ) + for fused_partition in fused_partitions: + conv_partition, bn_partition, binary_partition = fused_partition + ( + conv_node, + bn_output_node, + binary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, binary_partition] + ) + if len(bn_output_node.users) != 1: + # Conv BN pattern should only has 1 user. + continue + ( + bn_output_node_idx, + extra_input_node_idx, + ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node) + if (bn_output_node_idx is None) or (extra_input_node_idx is None): + continue + if bn_output_node != binary_node.args[bn_output_node_idx]: + raise ValueError(f"{bn_output_node} doesn't match input of binary node") + + extra_input_node = binary_node.args[extra_input_node_idx] + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate([binary_node, bn_output_node, conv_node], filter_fn): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + + if quantization_config is not None: + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize(binary_node) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(binary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardswish], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.SiLU], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + + for fused_partition in fused_partitions: + conv_partition, bn_partition, unary_partition = fused_partition + ( + conv_node, + bn_output_node, + unary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, unary_partition] + ) + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate([unary_node, bn_output_node, conv_node], filter_fn): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + if quantization_config is not None: + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize(unary_node) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(unary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d] + ) + for fused_partition in fused_partitions: + conv_partition, bn_partition = fused_partition + conv_node, bn_output_node = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition] + ) + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _skip_annotate([bn_output_node, conv_node], filter_fn): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + if quantization_config is not None: + bn_output_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + else: + _annotate_nodes_not_quantize(bn_output_node) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + if (quantization_config is None) or (quantization_config.is_qat): + # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat + self._annotate_qat_conv2d_fusion_pattern( + model, quantization_config, filter_fn + ) + self._annotate_conv2d_binary_unary(model, quantization_config, filter_fn) + self._annotate_conv2d_binary(model, quantization_config, filter_fn) + self._annotate_conv2d_unary(model, quantization_config, filter_fn) + self._annotate_conv2d(model, quantization_config, filter_fn) + + def _annotate_linear_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + self._annotate_linear_binary_unary(model, quantization_config, filter_fn) + self._annotate_linear_unary(model, quantization_config, filter_fn) + self._annotate_linear(model, quantization_config, filter_fn) + + def _annotate_matmul( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + for node in model.graph.nodes: + if node.target != torch.ops.aten.matmul.default: + continue + if _skip_annotate([node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize(node) + continue + + input_qspec_map = {} + matmul_node = node + for input_node in matmul_node.args: + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + matmul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d_binary_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + # Conv2d + add + unary op + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, operator.add, torch.nn.ReLU] + ) + for fused_partition in fused_partitions: + conv_partition, binary_partition, unary_partition = fused_partition + conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions( + [conv_partition, binary_partition, unary_partition] + ) + if len(conv_node.users) != 1: + # Conv Node should only has 1 user node + continue + conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( + conv_node, binary_node + ) + if (conv_node_idx is None) or (extra_input_node_idx is None): + continue + if conv_node != binary_node.args[conv_node_idx]: + raise ValueError(f"{conv_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + # No conv node found to be fused with add + continue + if _skip_annotate([unary_node, binary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node, unary_node]) + continue + + 
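# A sketch of the QAT flow the conv + batchnorm annotators above target. It
# assumes `prepare_qat_pt2e` is exposed by pt2e_flow.quantize_pt2e, mirroring
# the upstream torch.ao.quantization.quantize_pt2e API; the toy model and the
# single forward pass standing in for fine-tuning are assumptions as well.
import torch
from torch.export import export_for_training

from torchao.quantization.pt2e_flow.quantize_pt2e import (
    convert_pt2e,
    prepare_qat_pt2e,
)
from torchao.quantization.pt2e_flow.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8), torch.nn.ReLU()
)
example_inputs = (torch.randn(2, 3, 16, 16),)

exported = export_for_training(model, example_inputs).module()

quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config(is_qat=True))

prepared = prepare_qat_pt2e(exported, quantizer)  # conv + bn kept unfolded for QAT
prepared(*example_inputs)  # stand-in for QAT fine-tuning steps
quantized = convert_pt2e(prepared)  # folds conv + bn and inserts q/dq ops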
self._annotate_conv_node_helper(conv_node, False, quantization_config) + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d_binary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + # Conv2d + add + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, operator.add] + ) + for fused_partition in fused_partitions: + conv_partition, binary_partition = fused_partition + conv_node, binary_node = self._get_output_nodes_of_partitions( + [conv_partition, binary_partition] + ) + if len(conv_node.users) != 1: + # Conv Node should only has 1 user node + continue + conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( + conv_node, binary_node + ) + if (conv_node_idx is None) or (extra_input_node_idx is None): + continue + if conv_node != binary_node.args[conv_node_idx]: + raise ValueError(f"{conv_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + assert isinstance(conv_node, Node) + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + # No conv node found to be fused with add + continue + if _skip_annotate([binary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, binary_node]) + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.Hardswish], + [torch.nn.Conv2d, torch.nn.ReLU6], + [torch.nn.Conv2d, torch.nn.SiLU], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + + for fused_partition in fused_partitions: + conv_partition, unary_partition = fused_partition + conv_node, unary_node = self._get_output_nodes_of_partitions( + [conv_partition, unary_partition] + ) + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + if _skip_annotate([unary_node, conv_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([conv_node, unary_node]) + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d( + self, + gm: 
torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + conv_partitions = get_source_partitions( + gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] + ) + conv_partitions = list(itertools.chain.from_iterable(conv_partitions.values())) + for conv_partition in conv_partitions: + if len(conv_partition.output_nodes) > 1: + raise ValueError("conv partition has more than one output node") + conv_node = conv_partition.output_nodes[0] + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + raise ValueError(f"{conv_node} is not an aten conv2d operator") + # skip annotation if it is already annotated + if _skip_annotate([conv_node], filter_fn): + continue + self._annotate_conv_node_helper(conv_node, True, quantization_config) + + def _annotate_maxpool2d( + self, + node: Node, + quantization_config: Optional[QuantizationConfig], + ) -> None: + if node.target is not torch.ops.aten.max_pool2d.default: + return + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + + maxpool_node = node + if _is_any_annotated( + [ + maxpool_node, + ] + ): + return + + input_node = maxpool_node.args[0] + assert isinstance(input_node, Node) + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_cat( + self, node: Node, quantization_config: QuantizationConfig + ) -> None: + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + cat_node = node + input_nodes = cat_node.args[0] + assert isinstance(input_nodes, Sequence) + first_input_node = input_nodes[0] + input_qspec_map = {} + assert isinstance(first_input_node, Node) + assert isinstance(cat_node, Node) + input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config) + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, cat_node) + ) + + for input_node in input_nodes[1:]: + if input_node not in input_qspec_map: + # There has the case of cat same nodes: torch.cat([input0, input0], 1) + assert isinstance(input_node, Node) + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + + cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_propagation_quantizable_pattern_entry( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + for node in gm.graph.nodes: + self._annotate_propagation_quantizable_pattern( + node, quantization_config, filter_fn + ) + + def _annotate_propagation_quantizable_pattern( + self, node: Node, quantization_config, filter_fn + ) -> None: + # Propagate annotation to quantizable patterns. 
+ if ( + (node.target in propagation_quantizable_ops) + and (not _is_any_annotated([node])) + and (node.op == "call_function") + ): + + def is_all_inputs_connected_to_quantized_op(input_nodes): + # Ensure all the inputs connect to fusion pattern or quantized node + for input_node in input_nodes: + if not _is_quantized_op_pt2e(input_node): + return False + return True + + if _skip_annotate([node], filter_fn): + return + + if quantization_config is None: + _annotate_nodes_not_quantize(node) + return + + if node.target is torch.ops.aten.max_pool2d.default: + # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not + input_nodes_to_check = [node.all_input_nodes[0]] + if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + if quantization_config is not None: + warnings.warn( + f"The input of maxpool2d is not quantized, skip annotate maxpool2d with config {quantization_config}." + ) + return + + self._annotate_maxpool2d(node, quantization_config) + return + elif node.target is torch.ops.aten.cat.default: + input_nodes_to_check = node.all_input_nodes + if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + return + self._annotate_cat(node, quantization_config) + elif ( + node.target is torch.ops.aten.flatten.using_ints + and len(node.users) > 0 + and not any( + user.target in quantizable_ops for user in node.users.keys() + ) + ): + # Recipe of flatten: check if any users of flatten node are quantizable ops or not + return + else: + input_node = node.all_input_nodes[0] + if not is_all_inputs_connected_to_quantized_op( + [ + input_node, + ] + ): + return + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + return + + def _annotate_output_share_observer_as_input( + self, input_node: Node, source_node: Node + ): + source_node_quantization_annotation = ( + source_node.meta[QUANT_ANNOTATION_KEY] + if QUANT_ANNOTATION_KEY in source_node.meta + else None + ) + if ( + source_node_quantization_annotation + and source_node_quantization_annotation._is_output_of_quantized_pattern + ): + edge_or_node = (input_node, source_node) + source_node_quantization_annotation.output_qspec = SharedQuantizationSpec( + edge_or_node + ) + return + + def _annotate_output_for_int8_in_int8_out_pattern_entry( + self, + model: torch.fx.GraphModule, + ): + for node in model.graph.nodes: + self._annotate_output_for_int8_in_int8_out_pattern(node) + + def _annotate_output_for_int8_in_int8_out_pattern( + self, + node: Node, + ) -> None: + r""" + Check and insert observer at output of node in int8_in_int8_out_ops if needed. 
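# A sketch of what the propagation recipe above produces (toy model and sizes
# are assumptions): with the default config, the max_pool2d following a
# quantized conv2d gets its input annotated, and the int8-in-int8-out handling
# in this file then shares the output observer with that input, so the whole
# conv -> maxpool chain can stay in int8.
import torch
from torch.export import export_for_training

from torchao.quantization.pt2e_flow.quantize_pt2e import prepare_pt2e
from torchao.quantization.pt2e_flow.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)


class ConvPool(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.pool = torch.nn.MaxPool2d(2)

    def forward(self, x):
        return self.pool(self.conv(x))


example_inputs = (torch.randn(1, 3, 16, 16),)
exported = export_for_training(ConvPool().eval(), example_inputs).module()

quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config())
prepared = prepare_pt2e(exported, quantizer)  # maxpool shares observers with its input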
+ Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/ + 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495 + """ + edge_or_node: tuple[Node, Node] + if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): + if node.target == torch.ops.aten.max_pool2d.default: + maxpool_node = node + if not _is_all_annotated( + [ + maxpool_node, + ] + ): + return + + # Get the quantization_annotation from getitem_node + maxpool_node_quantization_annotation = ( + maxpool_node.meta[QUANT_ANNOTATION_KEY] + if QUANT_ANNOTATION_KEY in maxpool_node.meta + else None + ) + if ( + maxpool_node_quantization_annotation + and maxpool_node_quantization_annotation._is_output_of_quantized_pattern + ): + # Annotate the output_qspec of getitem_node + input_act = maxpool_node.args[0] + assert isinstance(input_act, Node) + assert isinstance(maxpool_node, Node) + edge_or_node = (input_act, maxpool_node) + maxpool_node_quantization_annotation.output_qspec = ( + SharedQuantizationSpec(edge_or_node) + ) + else: + input_node = node.all_input_nodes[0] + self._annotate_output_share_observer_as_input(input_node, node) + return + + def _annotate_linear( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + linear_partitions = get_source_partitions( + gm.graph, [torch.nn.Linear, torch.nn.functional.linear] + ) + linear_partitions = list( + itertools.chain.from_iterable(linear_partitions.values()) + ) + for partition in linear_partitions: + if len(partition.output_nodes) > 1: + raise ValueError( + "Linear partition cannot have more than one output node" + ) + linear_node = partition.output_nodes[0] + if linear_node.op != "call_function" or linear_node.target not in ( + torch.ops.aten.linear.default, + ): + raise ValueError(f"{linear_node} is not an aten linear operator") + # skip annotation if it is already annotated + if _skip_annotate([linear_node], filter_fn): + continue + self._annotate_linear_node_helper(linear_node, True, quantization_config) + + def _annotate_linear_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + postop_list = [ + torch.nn.ReLU, + torch.nn.LeakyReLU, + torch.nn.Tanh, + torch.nn.GELU, + ] + fused_partitions: list[tuple] = [] + for postop in postop_list: + fused_partitions = fused_partitions + find_sequential_partitions( + gm, [torch.nn.Linear, postop] + ) + for fused_partition in fused_partitions: + linear_partition, unary_partition = fused_partition + linear_node, unary_node = self._get_output_nodes_of_partitions( + [linear_partition, unary_partition] + ) + if linear_node.op != "call_function" or linear_node.target not in ( + torch.ops.aten.linear.default, + ): + continue + if _skip_annotate([unary_node, linear_node], filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize([linear_node, unary_node]) + continue + + self._annotate_linear_node_helper(linear_node, False, quantization_config) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_linear_binary_unary( + self, + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ) -> None: + # linear + binary_op + (optional) unary op + binary_op_list = [operator.add] + 
unary_op_list = [torch.nn.ReLU, None] + combinations = itertools.product(binary_op_list, unary_op_list) + for binary_op, unary_op in combinations: + has_unary = unary_op is not None + seq_partition = [torch.nn.Linear, binary_op] + if has_unary: + seq_partition.append(unary_op) + fused_partitions = find_sequential_partitions(gm, seq_partition) + for fused_partition in fused_partitions: + unary_partition, unary_node = None, None + if has_unary: + ( + linear_partition, + binary_partition, + unary_partition, + ) = fused_partition + ( + linear_node, + binary_node, + unary_node, + ) = self._get_output_nodes_of_partitions( + [linear_partition, binary_partition, unary_partition] + ) + else: + linear_partition, binary_partition = fused_partition + linear_node, binary_node = self._get_output_nodes_of_partitions( + [linear_partition, binary_partition] + ) + if len(linear_node.users) != 1: + # Linear Node should only has 1 user node + continue + ( + linear_node_idx, + extra_input_node_idx, + ) = self._get_input_idx_for_binary_node(linear_node, binary_node) + if (linear_node_idx is None) or (extra_input_node_idx is None): + continue + if linear_node != binary_node.args[linear_node_idx]: + raise ValueError( + f"{linear_node} doesn't match input of binary node" + ) + assert isinstance(linear_node, Node) + if ( + linear_node.op != "call_function" + or linear_node.target != torch.ops.aten.linear.default + ): + # No linear node found to be fused with add + continue + node_list = ( + [binary_node, linear_node] + if unary_node is None + else [unary_node, binary_node, linear_node] + ) + if _skip_annotate(node_list, filter_fn): + continue + + if quantization_config is None: + _annotate_nodes_not_quantize(node_list) + continue + + self._annotate_linear_node_helper( + linear_node, False, quantization_config + ) + # We don't insert q-dq before the binary input node due to accuracy issues + binary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + input_qspec_map={}, + _annotated=True, + _is_output_of_quantized_pattern=(not has_unary), + ) + ) + if unary_node is not None: + unary_node.meta[QUANT_ANNOTATION_KEY] = ( + _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass diff --git a/torchao/quantization/pt2e_flow/quantizer/xnnpack_quantizer.py b/torchao/quantization/pt2e_flow/quantizer/xnnpack_quantizer.py new file mode 100644 index 0000000000..2c76804b3d --- /dev/null +++ b/torchao/quantization/pt2e_flow/quantizer/xnnpack_quantizer.py @@ -0,0 +1,447 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import copy +import functools +import warnings +from typing import TYPE_CHECKING, Any, Callable, Optional + +import torch +import torch._dynamo as torchdynamo +import torch.nn.functional as F +from torch.fx._compatibility import compatibility + +from torchao.quantization.pt2e_flow.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torchao.quantization.pt2e_flow.observer import ( + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, + PlaceholderObserver, +) +from torchao.quantization.pt2e_flow.quantizer import QuantizationSpec, Quantizer +from torchao.quantization.pt2e_flow.quantizer.utils import _get_module_name_filter +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer_utils import ( + OP_TO_ANNOTATOR, + OperatorConfig, + 
OperatorPatternType, + QuantizationConfig, + _convert_scalars_to_attrs, + propagate_annotation, +) + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + from torch.fx import Node + + +__all__ = [ + "XNNPACKQuantizer", + "get_symmetric_quantization_config", +] + + +def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph: + gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs) + gm.graph.eliminate_dead_code() + return gm.graph + + +def _get_linear_patterns(input_size: list[int]): + in_channels = input_size[-1] + out_channels = 8 # hard coding but this should not matter + weight = torch.ones((out_channels, in_channels)) + bias = torch.ones((out_channels,)) + act = torch.ones(input_size) + + def linear_op(act, weight, bias=None): + return F.linear(act, weight, bias) + + pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias)) + pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight)) + return [pattern_w_bias, pattern_wo_bias] + + +def _supported_symmetric_quantized_operators() -> dict[str, list[OperatorPatternType]]: + supported_operators: dict[str, list[OperatorPatternType]] = { + # Both conv and linear should be able to handle relu + hardtanh fusion since + # those are clamp ops + "conv2d": [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, F.relu], + [F.conv2d, torch.nn.ReLU], + [F.conv2d, F.relu], + ], + "linear": [[torch.nn.Linear], [F.linear]], + "add": [[torch.add]], + "adaptive_avg_pool2d": [ + [torch.nn.AdaptiveAvgPool2d], + [F.adaptive_avg_pool2d], + ], + } + return copy.deepcopy(supported_operators) + + +def _get_supported_symmetric_config_and_operators() -> list[OperatorConfig]: + supported_config_and_operators: list[OperatorConfig] = [] + for quantization_config in [ + get_symmetric_quantization_config(), + get_symmetric_quantization_config(is_qat=True), + get_symmetric_quantization_config(is_per_channel=True), + get_symmetric_quantization_config(is_per_channel=True, is_qat=True), + ]: + ops = _supported_symmetric_quantized_operators() + supported_config_and_operators.extend( + OperatorConfig(quantization_config, pattern_list) + for pattern_list in ops.values() + ) + return copy.deepcopy(supported_config_and_operators) + + +@functools.lru_cache +def get_symmetric_quantization_config( + is_per_channel: bool = False, + is_qat: bool = False, + is_dynamic: bool = False, + act_qmin: int = -128, + act_qmax: int = 127, + weight_qmin: int = -127, + weight_qmax: int = 127, +): + extra_args: dict[str, Any] = {"eps": 2**-12} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=act_qmin, + quant_max=act_qmax, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args, + ), + ) + weight_qscheme = ( + torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric + ) + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + MinMaxObserver + ) + if is_qat: 
+ # TODO: qat + per channel? + weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize + elif is_per_channel: + weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + + extra_args: dict[str, Any] = {"eps": 2**-12} + if is_qat: + if weight_qscheme == torch.per_tensor_symmetric: + extra_args["observer"] = MovingAverageMinMaxObserver + else: + extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=weight_qmin, + quant_max=weight_qmax, + qscheme=weight_qscheme, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + bias_quantization_spec = None + if is_dynamic: + quantization_config = QuantizationConfig( + act_quantization_spec, + None, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + else: + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + return quantization_config + + +def _get_supported_config_and_operators() -> list[OperatorConfig]: + return _get_supported_symmetric_config_and_operators() + + +def _get_module_type_filter(tp: Callable): + """Get the module_type_filter function for a given module type, the filter accepts + a node and checks if the node comes from a module that has certain module type + + For example: + node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear + + + >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule + >> print(module_type_filter(node)) + True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well) + """ + + tp_str = tp.__module__ + "." + tp.__qualname__ + + def module_type_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + nn_module_stack = n.meta.get("nn_module_stack", {}) + types = [] + for _, t in nn_module_stack.values(): + # export() returns str, but older APIs (e.g. capture_pre_autograd_graph) + # return type. Handle both cases. + if isinstance(t, type): + t = t.__module__ + "." + t.__qualname__ + types.append(t) + return tp_str in types + + return module_type_filter + + +def _get_not_module_type_or_name_filter( + tp_list: list[Callable], module_name_list: list[str] +) -> Callable[[Node], bool]: + module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + + def not_module_type_or_name_filter(n: Node) -> bool: + return not any(f(n) for f in module_type_filters + module_name_list_filters) + + return not_module_type_or_name_filter + + +@compatibility(is_backward_compatible=False) +class XNNPACKQuantizer(Quantizer): + """ + !!! DEPRECATED !!! + XNNPACKQuantizer is a marked as deprected. It will be removed in the future. + It has been moved to executorch.backends.xnnpack.quantizer.xnnpack_quantizer.XNNPACKQuantizer. + Please use the new quantizer instead. 
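Since get_symmetric_quantization_config is wrapped in functools.lru_cache, callers request one of a handful of canonical variants by keyword. A short sketch of the variants supported by the helper above and of feeding one to the (deprecated) quantizer:

    from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # Static, per-tensor (default): HistogramObserver activations, MinMaxObserver weights.
    static_cfg = get_symmetric_quantization_config()

    # Static, per-channel weights: PerChannelMinMaxObserver with ch_axis=0.
    per_channel_cfg = get_symmetric_quantization_config(is_per_channel=True)

    # Dynamic: activations use PlaceholderObserver and the config carries no output activation spec.
    dynamic_cfg = get_symmetric_quantization_config(is_dynamic=True)

    # QAT: fake-quant constructors (FusedMovingAvgObsFakeQuantize) instead of plain observers.
    qat_cfg = get_symmetric_quantization_config(is_per_channel=True, is_qat=True)

    quantizer = XNNPACKQuantizer().set_global(per_channel_cfg)  # emits the deprecation warning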
+ """ + + supported_config_and_operators = _get_supported_config_and_operators() + STATIC_QAT_ONLY_OPS = [ + "conv_bn_relu", + "conv_bn", + "conv_transpose_bn_relu", + "conv_transpose_bn", + ] + + # static quantization ops (both PTQ and QAT) + # Preserve the order that fusions come before singular ops + STATIC_OPS = [ + "linear_relu", + "linear", + "conv_relu", + "conv", + "conv_transpose_relu", + "adaptive_avg_pool2d", + # TODO: move this to BoltNNQuantizer? + "gru_io_only", + "add_relu", + "add", + "mul_relu", + "mul", + "cat", + ] + + DYNAMIC_OPS = [ + "linear", + ] + + def __init__(self) -> None: + super().__init__() + warnings.warn(f"{self.__class__.__name__} is deprecated!") + self.global_config: Optional[QuantizationConfig] = None + self.operator_type_config: dict[ + torch._ops.OpOverloadPacket, Optional[QuantizationConfig] + ] = {} + self.module_type_config: dict[Callable, Optional[QuantizationConfig]] = {} + self.module_name_config: dict[str, Optional[QuantizationConfig]] = {} + + @classmethod + def get_supported_quantization_configs(cls) -> list[QuantizationConfig]: + op_configs: set[QuantizationConfig] = { + spec for spec, _ in cls.supported_config_and_operators + } + return list(op_configs) + + @classmethod + def get_supported_operator_for_quantization_config( + cls, quantization_config: Optional[QuantizationConfig] + ) -> list[OperatorPatternType]: + if quantization_config is None: + all_ops = [] + for _, ops in cls.supported_config_and_operators: + all_ops.extend(ops) + return all_ops + + for config, ops in cls.supported_config_and_operators: + # note: this assumes each entry in cls.supported_spec_and_operators + # corresponds to one spec, e.g. we don't have + # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] + # where the first and second entry have the same spec but did not + # merge the op list + if config == quantization_config: + return ops + return [] + + def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer: + self.global_config = quantization_config + return self + + def set_operator_type( + self, + operator_type: torch._ops.OpOverloadPacket, + quantization_config: QuantizationConfig, + ) -> XNNPACKQuantizer: + self.operator_type_config[operator_type] = quantization_config + return self + + def set_module_type( + self, module_type: Callable, quantization_config: QuantizationConfig + ): + """Set quantization_config for a submodule with type: `module_type`, for example: + quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator + patterns in the submodule with this module type with the given `quantization_config` + """ + self.module_type_config[module_type] = quantization_config + return self + + def set_module_name( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + """ + assert ( + quantization_config is not None + ), " quantization_config == None is not supported yet" + self.module_name_config[module_name] = quantization_config + return self + + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + """Transforms scalar values to tensor attributes""" + return _convert_scalars_to_attrs(model) + + def annotate(self, model: 
torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + # hacked for handling dynamic linear quant. will fix later. + if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] + model = self._annotate_for_dynamic_quantization_config(model) + else: + model = self._annotate_for_static_quantization_config(model) + propagate_annotation(model) + return model + + def _annotate_all_static_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + if quantization_config.is_qat: + for op in self.STATIC_QAT_ONLY_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + for op in self.STATIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_all_dynamic_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + for op in self.DYNAMIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_for_static_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_static_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, module_name_list), + ) + return model + + def _annotate_for_dynamic_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_dynamic_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, module_name_list), + ) + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + @classmethod + def get_supported_operators(cls) -> list[OperatorConfig]: + return cls.supported_config_and_operators diff --git a/torchao/quantization/pt2e_flow/quantizer/xnnpack_quantizer_utils.py b/torchao/quantization/pt2e_flow/quantizer/xnnpack_quantizer_utils.py new file mode 100644 index 0000000000..0b32d2170e --- /dev/null +++ b/torchao/quantization/pt2e_flow/quantizer/xnnpack_quantizer_utils.py @@ -0,0 +1,1127 @@ +# mypy: allow-untyped-defs +import itertools +import typing +from dataclasses import dataclass +from typing import Callable, NamedTuple, Optional + +import torch +import 
torch.nn.functional as F +from torch._subclasses import FakeTensor +from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix +from torch.fx import Node +from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap, +) +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + +from torchao.quantization.pt2e_flow.pt2e.export_utils import _WrapperModule +from torchao.quantization.pt2e_flow.pt2e.utils import ( + _get_aten_graph_module_for_pattern, + _is_conv_node, + _is_conv_transpose_node, +) +from torchao.quantization.pt2e_flow.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + SharedQuantizationSpec, +) +from torchao.quantization.pt2e_flow.quantizer.utils import ( + _annotate_input_qspec_map, + _annotate_output_qspec, +) + +__all__ = [ + "OperatorConfig", + "OperatorPatternType", + "QuantizationConfig", + "get_input_act_qspec", + "get_output_act_qspec", + "get_weight_qspec", + "get_bias_qspec", + "OP_TO_ANNOTATOR", + "propagate_annotation", +] + + +# In the absence of better name, just winging it with QuantizationConfig +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + weight: Optional[QuantizationSpec] + bias: Optional[QuantizationSpec] + # TODO: remove, since we can use observer_or_fake_quant_ctr to express this + is_qat: bool = False + + +# Use Annotated because list[Callable].__module__ is read-only. +OperatorPatternType = typing.Annotated[list[Callable], None] +OperatorPatternType.__module__ = ( + "torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer_utils" +) + +AnnotatorType = Callable[ + [ + torch.fx.GraphModule, + Optional[QuantizationConfig], + Optional[Callable[[Node], bool]], + ], + Optional[list[list[Node]]], +] +OP_TO_ANNOTATOR: dict[str, AnnotatorType] = {} + + +def register_annotator(op: str) -> Callable[[AnnotatorType], None]: + def decorator(annotator: AnnotatorType) -> None: + OP_TO_ANNOTATOR[op] = annotator + + return decorator + + +class OperatorConfig(NamedTuple): + # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]] + # Basically we are mapping a quantization config to some list of patterns. + # a pattern is defined as a list of nn module, function or builtin function names + # e.g. [nn.Conv2d, torch.relu, torch.add] + # We have not resolved whether fusion can be considered internal details of the + # quantizer hence it does not need communication to user. + # Note this pattern is not really informative since it does not really + # tell us the graph structure resulting from the list of ops. 
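register_annotator only records the decorated function in OP_TO_ANNOTATOR under the given key; note that the inner decorator returns None, so registered annotators are reachable only through the dictionary, not by their module-level names. A hypothetical sketch of registering an extra pattern (the "hardswish" key and the annotator body are made up for illustration and are not part of this patch; the key would also have to be added to a quantizer's op list, e.g. XNNPACKQuantizer.STATIC_OPS, before anything calls it):

    from typing import Callable, Optional

    import torch
    from torch.fx import Node

    from torchao.quantization.pt2e_flow.quantizer import QuantizationAnnotation
    from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer_utils import (
        OP_TO_ANNOTATOR,
        QuantizationConfig,
        get_input_act_qspec,
        get_output_act_qspec,
        register_annotator,
    )

    @register_annotator("hardswish")  # hypothetical key
    def _annotate_hardswish(
        gm: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ):
        annotated_partitions = []
        for node in gm.graph.nodes:
            if node.op != "call_function" or node.target != torch.ops.aten.hardswish.default:
                continue
            if filter_fn and not filter_fn(node):
                continue
            input_act = node.args[0]
            assert isinstance(input_act, Node)
            node.meta["quantization_annotation"] = QuantizationAnnotation(
                input_qspec_map={input_act: get_input_act_qspec(quantization_config)},
                output_qspec=get_output_act_qspec(quantization_config),
                _annotated=True,
            )
            annotated_partitions.append([node])
        return annotated_partitions

    # The decorator returned None, so look the annotator up by key rather than by name.
    assert "hardswish" in OP_TO_ANNOTATOR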
+ config: QuantizationConfig + operators: list[OperatorPatternType] + + +def _is_annotated(nodes: list[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + +def _mark_nodes_as_annotated(nodes: list[Node]): + for node in nodes: + if node is not None: + if "quantization_annotation" not in node.meta: + node.meta["quantization_annotation"] = QuantizationAnnotation() + node.meta["quantization_annotation"]._annotated = True + + +def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + if quantization_config.input_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.input_activation + assert quantization_spec.qscheme in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ] + return quantization_spec + + +def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + if quantization_config.output_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.output_activation + assert quantization_spec.qscheme in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ] + return quantization_spec + + +def get_weight_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + assert quantization_config is not None + if quantization_config.weight is None: + return None + quantization_spec: QuantizationSpec = quantization_config.weight + if quantization_spec.qscheme not in [ + torch.per_tensor_symmetric, + torch.per_channel_symmetric, + None, + ]: + raise ValueError( + f"Unsupported quantization_spec {quantization_spec} for weight" + ) + return quantization_spec + + +def get_bias_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + assert quantization_config is not None + if quantization_config.bias is None: + return None + quantization_spec: QuantizationSpec = quantization_config.bias + assert ( + quantization_spec.dtype == torch.float + ), "Only float dtype for bias is supported for bias right now" + return quantization_spec + + +@register_annotator("linear") +def _annotate_linear( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target != torch.ops.aten.linear.default: + continue + if filter_fn and not filter_fn(node): + continue + act_node = node.args[0] + weight_node = node.args[1] + bias_node = None + if len(node.args) > 2: + bias_node = node.args[2] + + if _is_annotated([node]) is False: # type: ignore[list-item] + _annotate_input_qspec_map( + node, + act_node, + input_act_qspec, + ) + _annotate_input_qspec_map( + node, + weight_node, + weight_qspec, + ) + nodes_to_mark_annotated = [node, weight_node] + if 
bias_node: + _annotate_input_qspec_map( + node, + bias_node, + bias_qspec, + ) + nodes_to_mark_annotated.append(bias_node) + _annotate_output_qspec(node, output_act_qspec) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + annotated_partitions.append(nodes_to_mark_annotated) + + return annotated_partitions + + +@register_annotator("linear_relu") +def _annotate_linear_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_linear_node = node.args[0] + if ( + not isinstance(maybe_linear_node, Node) + or maybe_linear_node.op != "call_function" + or maybe_linear_node.target != torch.ops.aten.linear.default + ): + continue + + linear_node = maybe_linear_node + if len(linear_node.users) > 1: + # if linear node has multiple users, then it can't be fused with relu + continue + + input_qspec_map = {} + input_act = linear_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = input_act_qspec + + weight = linear_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = weight_qspec + + # adding weight node to the partition as well + partition = [relu_node, linear_node, weight] + bias = linear_node.args[2] if len(linear_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = bias_qspec + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + linear_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("conv") +def _annotate_conv( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for n in gm.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ]: + continue + conv_node = n + + input_qspec_map = {} + input_act = conv_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = get_input_act_qspec(quantization_config) + + weight = conv_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = get_weight_qspec(quantization_config) + + # adding weight node to the partition as well + partition = [conv_node, conv_node.args[1]] + + bias = conv_node.args[2] if len(conv_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = get_bias_qspec(quantization_config) + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + 
input_qspec_map=input_qspec_map, + output_qspec=get_output_act_qspec(quantization_config), + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +def _do_annotate_conv_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + is_conv_transpose: bool = False, +): + annotated_partitions = [] + for n in gm.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = n + maybe_conv_node = n.args[0] + + is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node + if not isinstance(maybe_conv_node, Node) or not is_conv_node(maybe_conv_node): + continue + conv_node = maybe_conv_node + + if len(conv_node.users) > 1: + # relu shouldn't be fuseable to conv if there are other users + # of convolution + continue + + input_qspec_map = {} + input_act = conv_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = get_input_act_qspec(quantization_config) + + weight = conv_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = get_weight_qspec(quantization_config) + + # adding weight node to the partition as well + partition = [relu_node, conv_node, conv_node.args[1]] + bias = conv_node.args[2] if len(conv_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = get_bias_qspec(quantization_config) + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, _annotated=True + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("conv_relu") +def _annotate_conv_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + return _do_annotate_conv_relu( + gm, quantization_config, filter_fn, is_conv_transpose=False + ) + + +@register_annotator("conv_transpose_relu") +def _annotate_conv_transpose_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + return _do_annotate_conv_relu( + gm, quantization_config, filter_fn, is_conv_transpose=True + ) + + +@register_annotator("conv_bn") +def _annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv + batchnorm parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. 
+ """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False) + + +@register_annotator("conv_bn_relu") +def _annotate_conv_bn_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv + batchnorm + relu parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True) + + +@register_annotator("conv_transpose_bn") +def _annotate_conv_transpose_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv_transpose + batchnorm parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn( + gm, quantization_config, filter_fn, has_relu=False, is_conv_transpose=True + ) + + +@register_annotator("conv_transpose_bn_relu") +def _annotate_conv_transpose_bn_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """ + Find conv_transpose + batchnorm + relu parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn( + gm, quantization_config, filter_fn, has_relu=True, is_conv_transpose=True + ) + + +def _do_annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]], + has_relu: bool, + is_conv_transpose: bool = False, +) -> list[list[Node]]: + """ + Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern, + return a list of annotated partitions. + + The output of the pattern must include a dictionary from string name to node + for the following names: "input", "conv", "weight", "bias", and "output". 
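The conv-bn annotators above only matter for QAT, where the captured training graph still contains batch_norm (in the PTQ flow it has already been folded into the conv). A minimal sketch of that flow, assuming the migrated quantize_pt2e module keeps the prepare_qat_pt2e/convert_pt2e entry points of the upstream torch.ao API:

    import torch
    from torch.export import export_for_training

    # Entry-point module path is an assumption based on the package layout in this patch.
    from torchao.quantization.pt2e_flow.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
    from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class ConvBNReLU(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.bn = torch.nn.BatchNorm2d(8)

        def forward(self, x):
            return torch.nn.functional.relu(self.bn(self.conv(x)))

    example_inputs = (torch.randn(1, 3, 32, 32),)
    m = export_for_training(ConvBNReLU(), example_inputs).module()

    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
    )

    # The "conv_bn_relu" annotator above matches the traced conv -> batch_norm -> relu
    # subgraph; prepare_qat_pt2e then rewrites it with fake quantization for training.
    m = prepare_qat_pt2e(m, quantizer)
    # ... run QAT fine-tuning on m ...
    m = convert_pt2e(m)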
+ """ + + # Example inputs for conv-bn1d patterns + _conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + # Example inputs for conv-bn2d patterns + _conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + + def get_pattern(conv_fn: Callable, relu_is_inplace: bool): + def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv): + conv = conv_fn(x, conv_weight, conv_bias) + bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True) + if has_relu: + output = F.relu_(bn) if relu_is_inplace else F.relu(bn) + else: + output = bn + return output, { + "input": x, + "conv": conv, + "weight": conv_weight, + "bias": conv_bias, + "output": output, + } + + return _WrapperModule(_conv_bn) + + # Needed for matching, otherwise the matches gets filtered out due to unused + # nodes returned by batch norm + gm.graph.eliminate_dead_code() + gm.recompile() + + matches = [] + if is_conv_transpose: + combinations = [ + (F.conv_transpose1d, _conv1d_bn_example_inputs), + (F.conv_transpose2d, _conv2d_bn_example_inputs), + ] + else: + combinations = [ + (F.conv1d, _conv1d_bn_example_inputs), # type: ignore[list-item] + (F.conv2d, _conv2d_bn_example_inputs), # type: ignore[list-item] + ] + + # Add `is_cuda` and `relu_is_inplace` dimensions + combinations = itertools.product( # type: ignore[assignment] + combinations, + [True, False] if torch.cuda.is_available() else [False], # is_cuda + [True, False] if has_relu else [False], # relu_is_inplace + ) + + # Match against all conv dimensions and cuda variants + for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations: # type: ignore[misc] + pattern = get_pattern(conv_fn, relu_is_inplace) # type: ignore[has-type] + pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda) # type: ignore[has-type] + pattern.graph.eliminate_dead_code() + pattern.recompile() + matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True) + matches.extend(matcher.match(gm.graph)) + + # Annotate nodes returned in the matches + annotated_partitions = [] + for match in matches: + name_node_map = match.name_node_map + input_node = name_node_map["input"] + conv_node = name_node_map["conv"] + weight_node = name_node_map["weight"] + bias_node = name_node_map["bias"] + output_node = name_node_map["output"] + + # TODO: annotate the uses of input, weight, and bias separately instead + # of assuming they come from a single conv node. This is not possible today + # because input may have multiple users, and we can't rely on the conv node + # always being the first user. 
This was the case in models with skip + # connections like resnet18 + + # Validate conv args + if conv_node.args[0] is not input_node: + raise ValueError("Conv arg did not contain input node ", input_node) + if conv_node.args[1] is not weight_node: + raise ValueError("Conv arg did not contain weight node ", weight_node) + if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node: + raise ValueError("Conv arg did not contain bias node ", bias_node) + + # Skip if the partition is already annotated or is filtered out by the user + partition = [conv_node, weight_node] + if bias_node is not None: + partition.append(bias_node) + if _is_annotated(partition): + continue + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + # Annotate conv inputs and pattern output + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + if bias_node is not None: + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + output_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("gru_io_only") +def _annotate_gru_io_only( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn) + gru_partitions = list(itertools.chain.from_iterable(gru_partitions.values())) + annotated_partitions = [] + for gru_partition in gru_partitions: + annotated_partitions.append(gru_partition.nodes) + output_nodes = gru_partition.output_nodes + input_nodes = gru_partition.input_nodes + # skip annotation if it is already annotated + if _is_annotated(input_nodes + output_nodes): + continue + # inside each GRU partition, we should be able to annotate each linear + # subgraph + input_act = input_nodes[0] + input_act_user = next(iter(input_act.users.keys())) + assert isinstance(input_act, Node) + assert isinstance(input_act_user, Node) + input_act_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + hidden_state = input_nodes[1] + hidden_state_user = next(iter(hidden_state.users.keys())) + assert isinstance(hidden_state, Node) + assert isinstance(hidden_state_user, Node) + hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + hidden_state: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + assert len(output_nodes) == 2, "expecting GRU to have two outputs" + for output in output_nodes: + output.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), + _annotated=True, + ) + nodes_to_mark_annotated = list(gru_partition.nodes) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + return annotated_partitions + + +@register_annotator("adaptive_avg_pool2d") +def _annotate_adaptive_avg_pool2d( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: 
Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + """Always annotate adaptive_avg_pool2d op""" + module_partitions = get_source_partitions( + gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn + ) + partitions = list(itertools.chain.from_iterable(module_partitions.values())) + annotated_partitions = [] + for partition in partitions: + pool_node = partition.output_nodes[0] + if ( + pool_node.op != "call_function" + or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default + ): + raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator") + + if _is_annotated([pool_node]): + continue + + annotated_partitions.append(partition.nodes) + input_act = pool_node.args[0] + assert isinstance(input_act, Node) + + # only annotate input output sharing operator + # when the output of the input node is annotated + if ( + "quantization_annotation" not in input_act.meta + or not input_act.meta["quantization_annotation"]._annotated + or input_act.meta["quantization_annotation"].output_qspec is None + ): + input_act_qspec = get_input_act_qspec(quantization_config) + else: + input_act_qspec = SharedQuantizationSpec(input_act) + + # output sharing with input + output_act_qspec = SharedQuantizationSpec((input_act, pool_node)) + pool_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: input_act_qspec, + }, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +def _is_input_large_scalar(node: Node, gm: torch.fx.GraphModule): + """Check if input is a large scalar value. So that we can skip quantization for the node + since histc op (in HistogramObserver) only works for values up to certain upper bound + """ + if node.op == "get_attr": + qualified_name = str(node.target) + module_path, _, name = qualified_name.rpartition(".") + submod = gm.get_submodule(module_path) + tensor = getattr(submod, name) + # torch.histc works until this upper bound + HISTC_UPPER_BOUND = 3.4028235e15 + return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND + return False + + +def _is_input_non_float_tensor(node: Node): + """Check if the input is not a float tensor, so that we can skip quantization for the node + since observers only works with float Tensors + """ + if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): + return True + return node.meta["val"].dtype != torch.float32 + + +@register_annotator("add_relu") +def _annotate_add_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_add = node.args[0] + if ( + not isinstance(maybe_add, Node) + or maybe_add.op != "call_function" + or maybe_add.target + not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ] + ): + continue + + add_node = maybe_add + + if len(add_node.users) > 1: + # add can't be fused with ReLU if the result of add is being used + # else where in the graph + continue + + partition = [relu_node, add_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = 
get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = add_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + partition.append(input_act0) + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = add_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + partition.append(input_act1) + input_qspec_map[input_act1] = input_act_qspec + + add_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("add") +def _annotate_add( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ]: + continue + add_node = node + partition = [add_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = add_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + partition.append(input_act0) + + input_act1 = add_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + partition.append(input_act1) + + add_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("mul_relu") +def _annotate_mul_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_mul = node.args[0] + if ( + not isinstance(maybe_mul, Node) + or maybe_mul.op != "call_function" + or maybe_mul.target + not in [ + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul_.Tensor, + ] + ): + continue + + mul_node = maybe_mul + if len(mul_node.users) > 1: + # mul can't be fused with ReLU if the result of mul is being used + # else where in the graph + continue + + partition = [relu_node, mul_node] + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = mul_node.args[0] + if 
isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + partition.append(input_act0) + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = mul_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + partition.append(input_act1) + input_qspec_map[input_act1] = input_act_qspec + + mul_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("mul") +def _annotate_mul( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul_.Tensor, + ]: + continue + + mul_node = node + partition = [mul_node] + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + + input_qspec_map = {} + input_act0 = mul_node.args[0] + if isinstance(input_act0, Node): + if _is_input_large_scalar(input_act0, gm): + continue + if _is_input_non_float_tensor(input_act0): + continue + input_qspec_map[input_act0] = input_act_qspec + partition.append(input_act0) + + input_act1 = mul_node.args[1] + if isinstance(input_act1, Node): + if _is_input_large_scalar(input_act1, gm): + continue + if _is_input_non_float_tensor(input_act1): + continue + input_qspec_map[input_act1] = input_act_qspec + partition.append(input_act0) + + mul_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition) + return annotated_partitions + + +# TODO: remove Optional in return type, fix annotated_partitions logic +@register_annotator("cat") +def _annotate_cat( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[list[list[Node]]]: + cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn) + cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values())) + annotated_partitions = [] + for cat_partition in cat_partitions: + cat_node = cat_partition.output_nodes[0] + if _is_annotated([cat_node]): + continue + + if cat_node.target != torch.ops.aten.cat.default: + # TODO: change this to AnnotationException + raise Exception( # noqa: TRY002 + f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}" + " please check if you are calling the correct capture API" + ) + + annotated_partitions.append(cat_partition.nodes) + + input_act_qspec = get_input_act_qspec(quantization_config) + inputs = cat_node.args[0] + + input_qspec_map = {} + input_act0 = inputs[0] # type: ignore[index] + if isinstance(input_act0, Node): + input_qspec_map[input_act0] = input_act_qspec + + shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) # type: 
ignore[arg-type] + for input_act in inputs[1:]: # type: ignore[index, union-attr] + if input_act not in input_qspec_map: + input_qspec_map[input_act] = shared_with_input0_qspec # type: ignore[index] + + output_act_qspec = shared_with_input0_qspec + + cat_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + return annotated_partitions + + +def _is_share_obs_or_fq_op(op: Callable) -> bool: + return op in [ + torch.ops.aten.relu.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.max_pool2d.default, + torch.ops.aten.mean.default, + torch.ops.aten.mean.dim, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dim, + # TODO: remove? + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.view_copy.default, + torch.ops.aten.view.default, + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.flatten.using_ints, + ] + + +def propagate_annotation(model: torch.fx.GraphModule) -> None: + for n in model.graph.nodes: + if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target): + continue + + prev_node = n.args[0] + if not isinstance(prev_node, Node): + continue + + quantization_annotation = prev_node.meta.get("quantization_annotation", None) + if not quantization_annotation: + continue + + output_qspec = quantization_annotation.output_qspec + if not output_qspec: + continue + + # make sure current node is not annotated + if ( + "quantization_annotation" in n.meta + and n.meta["quantization_annotation"]._annotated + ): + continue + + shared_qspec = SharedQuantizationSpec(prev_node) + # propagate the previous output_qspec to the current node + n.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + prev_node: shared_qspec, + }, + output_qspec=shared_qspec, + _annotated=True, + ) + + +# TODO: make the list of ops customizable +def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for n in model.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mul.Tensor, + ]: + continue + args = list(n.args) + new_args = [] + for i in range(len(args)): + if isinstance(args[i], torch.fx.Node): + new_args.append(args[i]) + continue + prefix = "_tensor_constant_" + get_new_attr_name = get_new_attr_name_with_prefix(prefix) + tensor_constant_name = get_new_attr_name(model) + float_tensor = torch.tensor(float(args[i])) + model.register_buffer(tensor_constant_name, float_tensor) + fake_mode = n.meta["val"].fake_mode + with model.graph.inserting_before(n): + get_attr_node = model.graph.create_node( + "get_attr", tensor_constant_name, (), {} + ) + get_attr_node.meta["val"] = fake_mode.from_tensor( + float_tensor, static_shapes=True + ) + new_args.append(get_attr_node) + n.args = tuple(new_args) + model.recompile() + return model diff --git a/torchao/quantization/pt2e_flow/quantizer/xpu_inductor_quantizer.py b/torchao/quantization/pt2e_flow/quantizer/xpu_inductor_quantizer.py new file mode 100644 index 0000000000..2475190f9c --- /dev/null +++ b/torchao/quantization/pt2e_flow/quantizer/xpu_inductor_quantizer.py @@ -0,0 +1,125 @@ +# mypy: allow-untyped-defs +import functools +from typing import TYPE_CHECKING, Any, Optional + +import torch +from torch.fx import Node + +from torchao.quantization.pt2e_flow.observer import ( + HistogramObserver, + PerChannelMinMaxObserver, +) +from 
torchao.quantization.pt2e_flow.quantizer.quantizer import QuantizationSpec +from torchao.quantization.pt2e_flow.quantizer.x86_inductor_quantizer import ( + FilterFn, + X86InductorQuantizer, + _is_any_annotated, + int8_in_int8_out_ops, +) +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer_utils import ( + QuantizationConfig, +) + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + +__all__ = [ + "XPUInductorQuantizer", + "get_default_xpu_inductor_quantization_config", +] + + +@functools.lru_cache +def get_default_xpu_inductor_quantization_config(): + extra_args: dict[str, Any] = {"eps": 2**-12} + act_observer_or_fake_quant_ctr = HistogramObserver + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + PerChannelMinMaxObserver + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + bias_quantization_spec = None # will use placeholder observer by default + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + False, + ) + return quantization_config + + +class XPUInductorQuantizer(X86InductorQuantizer): + """ + XPUInductorQuantizer is a class designed to facilitate + quantization capability at Intel GPU backend. The class + highly reuses the existing implementation of + X86InductorQuantizer as both are intended to take advantage + of the optimized kernels in oneDNN library. + """ + + def __init__(self) -> None: + super().__init__() + + """ + Following annotate_xx overrides the impls in base class, as + no XPU implementation for these operators currently. We would + gradually enable the XPU implementation and remove following + overrides. We keep the annotate methods but make the function + body empty, aiming to let `_generate_qdq_quantized_model` + generate qdq around op and graph execute on fp32 dtype for + unspported operators. + """ + + def _annotate_qat_conv2d_fusion_pattern( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[FilterFn] = None, + ): + pass + + def _annotate_maxpool2d( + self, + node: Node, + quantization_config: Optional[QuantizationConfig], + ) -> None: + """ + Here we skip the annotate logic for maxpool at XPU backend + as the quantized::max_pool2d is only implemented for CPU. 
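A short usage sketch of the XPU quantizer defined above. It inherits the set_global driver flow from X86InductorQuantizer, and with the overrides in this class max_pool2d is simply left un-annotated, so it executes on dequantized fp32 values:

    from torchao.quantization.pt2e_flow.quantizer.xpu_inductor_quantizer import (
        XPUInductorQuantizer,
        get_default_xpu_inductor_quantization_config,
    )

    # The model then goes through the same prepare_pt2e / convert_pt2e steps shown
    # for the X86 quantizer; only the annotation recipes differ.
    quantizer = XPUInductorQuantizer()
    quantizer.set_global(get_default_xpu_inductor_quantization_config())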
+ """ + return + + def _annotate_output_for_int8_in_int8_out_pattern( + self, + node: Node, + ) -> None: + if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): + if node.target == torch.ops.aten.max_pool2d.default: + return + else: + input_node = node.all_input_nodes[0] + self._annotate_output_share_observer_as_input(input_node, node) + return diff --git a/torchao/quantization/pt2e_flow/utils.py b/torchao/quantization/pt2e_flow/utils.py new file mode 100644 index 0000000000..a1b2652632 --- /dev/null +++ b/torchao/quantization/pt2e_flow/utils.py @@ -0,0 +1,822 @@ +# mypy: allow-untyped-defs +""" +Utils shared by different modes of quantization (eager/graph) +""" + +import functools +import warnings +from collections import OrderedDict +from inspect import getfullargspec, signature +from typing import Any, Callable, Optional, Union + +import torch +from torch.ao.quantization.quant_type import QuantType +from torch.fx import Node +from torch.nn.utils.parametrize import is_parametrized + +NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] +NodePattern.__module__ = "torch.ao.quantization.utils" + +# This is the Quantizer class instance from torch/quantization/fx/quantize.py. +# Define separately to prevent circular imports. +# TODO(future PR): improve this. +# make this public once fixed (can't be public as is because setting the module directly +# doesn't work) +QuantizerCls = Any + +# Type for fusion patterns, it can be more complicated than the following actually, +# see pattern.md for docs +# TODO: not sure if typing supports recursive data types +Pattern = Union[ + Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any +] +Pattern.__module__ = "torch.ao.quantization.utils" + + +# TODO: maybe rename this to MatchInputNode +class MatchAllNode: + """A node pattern that matches all nodes, used in defining + fusion patterns in FX Graph Mode Quantization + """ + + +module_type_list = { + torch.nn.ReLU, + torch.nn.ReLU6, + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.Identity, + torch.nn.Hardsigmoid, + torch.nn.Sigmoid, + torch.nn.Tanh, +} +func_list = { + torch.nn.functional.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.instance_norm, + torch.nn.functional.layer_norm, + torch.nn.functional.leaky_relu, + torch.nn.functional.silu, + torch.nn.functional.mish, + torch.nn.functional.dropout, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.relu, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.hardsigmoid, + torch.nn.functional.sigmoid, + torch.transpose, + torch.repeat_interleave, + torch.sigmoid, + torch.squeeze, + torch.stack, + torch.sum, + torch.tanh, + torch.unsqueeze, + torch.cat, +} +method_list = { + torch.mean, + "relu", + "relu_", + "contiguous", + "detach", + "detach_", + "hardsigmoid", + "hardsigmoid_", + "permute", + "repeat", + "repeat_interleave", + "reshape", + "resize_", + "shape", + "sigmoid", + "sigmoid_", + "size", + "squeeze", + "squeeze_", + "tanh", + "tanh_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view", +} + + +# TODO: not used now, remove +def 
check_node(node, modules): + # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py + is_call_function = node.op == "call_function" and node.target in func_list + is_call_method = node.op == "call_method" and node.target in method_list + is_call_module = ( + node.op == "call_module" and type(modules[str(node.target)]) in module_type_list + ) + return is_call_function, is_call_method, is_call_module + + +def get_combined_dict(default_dict, additional_dict): + """ + Combines two dictionaries. + + This function takes two dictionaries as input and returns a new dictionary + that contains all the key-value pairs from both input dictionaries. + If there are any duplicate keys in the `additional_dict`, the values + from the `additional_dict` will overwrite those in the `default_dict`. + Args: + default_dict (dict): The main dictionary that will be used as the base + additional_dict (dict): The dictionary used to update `default_dict` + + Returns: + dict: The resulting dictionary + Example: + >>> x = dict(a=1, b=1) + >>> y = dict(b=2, c=3) + >>> get_combined_dict(x, y) + {'a': 1, 'b': 2, 'c': 3} + """ + d = default_dict.copy() + d.update(additional_dict) + return d + + +def is_per_tensor(qscheme): + return qscheme == torch.per_tensor_affine or qscheme == torch.per_tensor_symmetric + + +def is_per_channel(qscheme): + return qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + torch.per_channel_symmetric, + ] + + +def getattr_from_fqn(obj: Any, fqn: str) -> Any: + """ + Given an obj and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz. + """ + return functools.reduce(getattr, fqn.split("."), obj) + + +def to_underlying_dtype(qdtype): + DTYPE_MAPPING = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + torch.quint4x2: torch.uint8, + torch.quint2x4: torch.uint8, + torch.uint8: torch.uint8, + torch.int8: torch.int8, + torch.uint16: torch.uint16, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.float8_e5m2: torch.float8_e5m2, + torch.float8_e4m3fn: torch.float8_e4m3fn, + } + assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype) + return DTYPE_MAPPING[qdtype] + + +def get_qparam_dict(observer_or_fake_quant): + from torch.ao.quantization.observer import PlaceholderObserver + + qscheme = getattr(observer_or_fake_quant, "qscheme", None) + dtype = observer_or_fake_quant.dtype + qparams = {"qscheme": qscheme, "dtype": dtype} + + if not qscheme or isinstance(observer_or_fake_quant, PlaceholderObserver): + return {"qscheme": None, "dtype": dtype} + + if is_per_tensor(qscheme): + qscheme = torch.per_tensor_affine + elif is_per_channel(qscheme): + # change symmetric to affine since we do not have symmetric + # quantized Tensor + if qscheme == torch.per_channel_symmetric: + qscheme = torch.per_channel_affine + qparams["axis"] = observer_or_fake_quant.ch_axis + else: + raise RuntimeError(f"Unrecognized qscheme: {qscheme}") + # update qscheme, since we don't have symmetric quant qscheme + # in quantized Tensor + qparams["qscheme"] = qscheme + + scale, zero_point = observer_or_fake_quant.calculate_qparams() + qparams["scale"] = scale + qparams["zero_point"] = zero_point + + if hasattr(observer_or_fake_quant, "quant_min"): + qparams["quant_min"] = observer_or_fake_quant.quant_min + if hasattr(observer_or_fake_quant, "quant_max"): + qparams["quant_max"] = observer_or_fake_quant.quant_max + + return qparams + + +def get_swapped_custom_module_class( + custom_module, 
custom_module_class_mapping, qconfig +): + """Get the observed/quantized custom module class that we need + to swap `custom_module` to + Input: + custom_module: input, can be an instance of either a float or observed custom module + custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping + qconfig: qconfig configured for the custom module + + Output: + corresponding observed/quantized custom module class for input custom module instance + """ + quant_type = get_quant_type(qconfig) + class_mapping = custom_module_class_mapping.get(quant_type, {}) + assert type(custom_module) in class_mapping, ( + "did not find corresponding observed " + f"module class for {type(custom_module)} in mapping: {class_mapping}" + ) + return class_mapping[type(custom_module)] + + +def activation_dtype(qconfig): + assert qconfig is not None + activation = qconfig.activation() + return activation.dtype + + +def weight_dtype(qconfig): + assert qconfig is not None + weight = qconfig.weight() + return weight.dtype + + +def activation_is_statically_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized or not, this includes quantizing to quint8, qint8 and qint32 and float16 + """ + return activation_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float16, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] and (not activation_is_dynamically_quantized(qconfig)) + + +def activation_is_dynamically_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + dynamically quantized or not, this includes dynamically quantizing to + quint8, qint8 and float16 + """ + _activation_dtype, _, activation_is_dynamic = get_qconfig_dtypes(qconfig) + return activation_is_dynamic + + +def activation_is_int8_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized to int8 or not, this includes quantizing to quint8, qint8 + """ + return activation_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.uint8, + torch.int8, + ] + + +def activation_is_int32_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized to int32 or not + """ + return activation_dtype(qconfig) in [torch.qint32, torch.int32] + + +def weight_is_quantized(qconfig): + """Given a qconfig, decide if the weight needs to be + quantized or not + """ + return weight_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.float16, + torch.quint4x2, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] + + +def weight_is_statically_quantized(qconfig): + """Given a qconfig, decide if the weight needs to be statically + quantized or not + """ + return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8] + + +def op_is_int8_dynamically_quantized(qconfig) -> bool: + """Given a qconfig, returns True if this op is using int8 dynamic + quantization + """ + activation_dtype, weight_dtype, activation_is_dynamic = get_qconfig_dtypes(qconfig) + return ( + activation_dtype in [torch.quint8, torch.uint8] + and + # for now, the lines below assume fbgemm or qnnpack + weight_dtype in [torch.qint8, torch.int8] + and activation_is_dynamic + ) + + +def get_qconfig_dtypes(qconfig): + r"""returns the qconfig tuple for qconfig: + (activation_dtype, weight_dtype, activation_is_dynamic) + """ + assert qconfig is not None + activation = qconfig.activation() + weight = 
qconfig.weight() + act_is_dynamic = getattr(activation, "is_dynamic", False) + return (activation.dtype, weight.dtype, act_is_dynamic) + + +def get_quant_type(qconfig): + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + static_dtypes = [ + torch.quint8, + torch.qint8, + torch.quint4x2, + torch.qint32, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] + if weight.dtype in static_dtypes: + if hasattr(activation, "is_dynamic") and activation.is_dynamic: + return QuantType.DYNAMIC + elif activation.dtype in static_dtypes: + return QuantType.STATIC + else: + return QuantType.WEIGHT_ONLY + + if weight.dtype == torch.float16: + if hasattr(activation, "is_dynamic") and activation.is_dynamic: + return QuantType.DYNAMIC + elif activation.dtype == torch.float16: + return QuantType.STATIC + + raise Exception( # noqa: TRY002 + f"Unrecognized dtype combination in get_quant_type: activation({activation.dtype})," + f"weight({weight.dtype})" + ) + + +def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool: + """Checks if the given minimum and maximum values are valid, meaning that + they exist and the min value is not greater than the max value. + """ + if min_val.numel() == 0 or max_val.numel() == 0: + warnings.warn( + "must run observer before calling calculate_qparams. " + + "Returning default values." + ) + return False + + if min_val.dim() == 0 or max_val.dim() == 0: + if min_val == float("inf") and max_val == float("-inf"): + warnings.warn( + "must run observer before calling calculate_qparams. " + + "Returning default values." + ) + + return False + + assert min_val <= max_val, f"min {min_val} should be less than max {max_val}" + else: + assert torch.all( + min_val <= max_val + ), f"min {min_val} should be less than max {max_val}" + + return True + + +def calculate_qmin_qmax( + quant_min: int, + quant_max: int, + has_customized_qrange: bool, + dtype: torch.dtype, + reduce_range: bool, +) -> tuple[int, int]: + r"""Calculates the actual qmin and qmax based on the quantization range, + the observer datatype, and whether the range is reduced. + """ + # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted. + if has_customized_qrange: + # This initialization is here to resolve TorchScript compilation issues and to allow + # refinement to decouple initial_qmin and initial_qmax from the quantization range. + # The actual values of initial_qmin and initial_qmax will be reset below. + if dtype in [torch.qint32, torch.int32]: + initial_quant_min, initial_quant_max = 0, 2**32 - 1 + else: + initial_quant_min, initial_quant_max = 0, 255 + # The following assignment of quant_min and quant_max to local variables, together with the if + # check, refines them from Optional to valid integers, as required by TorchScript. + custom_quant_min, custom_quant_max = quant_min, quant_max + if custom_quant_min is not None and custom_quant_max is not None: + initial_quant_min, initial_quant_max = ( + custom_quant_min, + custom_quant_max, + ) + + qrange_len = initial_quant_max - initial_quant_min + 1 + if dtype in [torch.qint8, torch.int8]: + assert ( + 0 < qrange_len <= 256 + ), "quantization range should be positive and not exceed the maximum bit range (=256)." + elif dtype in [torch.qint32, torch.int32]: + assert ( + 0 < qrange_len <= 2**32 + ), "quantization range should be positive and not exceed the maximum bit range (=4294967296)." 
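+        # note: with reduce_range the (possibly customized) bounds are halved by floor division, e.g. (-128, 127) -> (-64, 63)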
+ if reduce_range: + quant_min, quant_max = quant_min // 2, quant_max // 2 + else: + # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used. + if dtype in [torch.qint8, torch.int8]: + if reduce_range: + quant_min, quant_max = -64, 63 + else: + quant_min, quant_max = -128, 127 + elif dtype in [torch.quint8, torch.uint8]: + if reduce_range: + quant_min, quant_max = 0, 127 + else: + quant_min, quant_max = 0, 255 + elif dtype in [torch.qint32, torch.int32]: + quant_min, quant_max = -1 * (2**31), (2**31) - 1 + elif dtype in [torch.uint16]: + quant_min, quant_max = 0, 2**16 - 1 + elif dtype in [torch.int16]: + quant_min, quant_max = -(2**15), 2**15 - 1 + else: + quant_min, quant_max = 0, 15 + return quant_min, quant_max + + +def _parent_name(target): + """ + Turn 'foo.bar' into ['foo', 'bar'] + """ + r = target.rsplit(".", 1) + if len(r) == 1: + return "", r[0] + else: + return r[0], r[1] + + +def has_no_children_ignoring_parametrizations(module): + """ + Checks if module._modules is empty or + if module is a parametrization, checks that module._modules only has + the 'parametrizations' module + """ + if len(module._modules) == 0: + return True + elif is_parametrized(module): + return len(module._modules) == 1 and "parametrizations" in module._modules + else: + return False + + +def _get_path_of_module( + root: torch.nn.Module, submodule: torch.nn.Module +) -> Optional[str]: + """Get the path (fully qualified name) of a submodule + + Example:: + + >> class M(torch.nn.Module): + def __init__(self) -> None: + self.linear = torch.nn.Linear(5, 5) + def forward(self, x): + return self.linear(x) + + >> m = M() + >> l = m.linear + >> _get_path_of_module(m, l) + "linear" + """ + for n, p in root.named_modules(): + if submodule is p: + return n + return None + + +def _get_signature_locals(f: Callable, loc: dict[str, Any]) -> dict[str, Any]: + """Get local keyword arguments + + Example:: + + >> def f(self, a, b=9): + pass + >> loc = {"a": 6, "c": 7} + >> _get_signature_locals(f, loc) + {"a": 6} + """ + return {k: v for k, v in loc.items() if k in signature(f).parameters} + + +def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]": + """Get all default keyword arguments from function signature + + Example:: + + >> def f(self, a, b=9): + pass + >> _get_default_kwargs(f) + {"b": 9} + """ + kwargs = {} + for name, param in signature(f).parameters.items(): + if param.default is not param.empty: + kwargs[name] = param.default + elif param.kind is param.VAR_POSITIONAL: + kwargs[name] = () + elif param.kind is param.VAR_KEYWORD: + kwargs[name] = {} + return OrderedDict(kwargs) + + +def _normalize_kwargs(func: Callable, loc: dict[str, Any]) -> "OrderedDict[str, Any]": + """Given a function and local function arguments, normalize the keyword + arguments by filling in default arguments from function signature + + Example:: + + >> def f(self, key1=3, key2=3): + pass + >> loc = {"key2": 6} + >> _normalize_kwargs(f, loc) + {"key1": 3, "key2": 6} + """ + default_kwargs = _get_default_kwargs(func) + local_kwargs = _get_signature_locals(func, loc) + normalized_kwargs = default_kwargs.copy() + for attr, val in local_kwargs.items(): + if attr in normalized_kwargs: + # override the default keyword arguments + normalized_kwargs[attr] = val + return normalized_kwargs + + +def validate_qmin_qmax(quant_min: int, quant_max: int) -> None: + r"""Validates that the user-specified quantization range is properly initialized + and within the given bound supported by the observer dtype. 
+ + To accommodate lower-bit quantization with respect to the existing torch.qint8 and + torch.quint8 datatypes, the user can choose to use a dynamic quantization range by passing + in a tuple of initial qmin and qmax values. One use case is that these customized qmin and qmax + values are used to calculate static estimates of the scale and zero point for aggressive lower-bit + fake quantization. These estimates are compared against parameters learned through backpropagation. + Related literature on learning scale and zero point via backpropagation: + + Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS + Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf + """ + # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted + # based on whether the quantization range is reduced and the datatype (signed/unsigned) used by the observer. + assert ( + quant_min <= 0 <= quant_max + ), "User-specified quantization range must include 0." + assert ( + quant_min < quant_max + ), "qmin must be strictly less than qmax for user-specified quantization range." + + +# Functionally equivalent to '_calculate_qparams' in observer.py. However, observers must be TorchScriptable, and +# qscheme, as far as I can tell, cannot be passed as a parameter in TorchScript functions. This makes refactoring the +# observer to use this utility a massive pain. For now I'm opting to duplicate the code, since it seems unlikely to change +# (last update over a year ago); once TorchScript is fully deprecated we can refactor. TODO(jakeszwe, jerryzh168) +def determine_qparams( + min_val: torch.Tensor, + max_val: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + eps: torch.Tensor, + has_customized_qrange: bool, + qscheme: torch.qscheme = torch.per_tensor_affine, +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters, given min and max + value tensors. Works for both the per-tensor and per-channel cases. + + Args: + min_val: Minimum values per channel + max_val: Maximum values per channel + + Returns: + scales: Scales tensor of shape (#channels,) + zero_points: Zero points tensor of shape (#channels,) + """ + if not check_min_max_valid(min_val, max_val): + return torch.tensor([1.0], device=min_val.device.type), torch.tensor( + [0], device=min_val.device.type + ) + + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + + device = min_val_neg.device + scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + eps = eps.to(device) + + if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric: + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = torch.max(scale, eps) + if dtype in [torch.uint8, torch.quint8]: + if has_customized_qrange: + # When a customized quantization range is used, the down-rounded midpoint of the range is chosen. 
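+                    # e.g. a customized 4-bit range (0, 15) yields zero_point (0 + 15) // 2 = 7, instead of the default 128 used below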
+ zero_point = zero_point.new_full( + zero_point.size(), (quant_min + quant_max) // 2 + ) + else: + zero_point = zero_point.new_full(zero_point.size(), 128) + elif qscheme == torch.per_channel_affine_float_qparams: + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.where(scale > eps, scale, torch.ones_like(scale)) + # We use the quantize function + # xq = Round(Xf * inv_scale + zero_point), + # setting zero_point to (-1 * min *inv_scale) we get + # Xq = Round((Xf - min) * inv_scale) + zero_point = -1 * min_val / scale + else: + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.max(scale, eps) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + + # For scalar values, cast them to Tensors of size 1 to keep the shape + # consistent with default values in FakeQuantize. + if len(scale.shape) == 0: + # TODO: switch to scale.item() after adding JIT support + scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) + if len(zero_point.shape) == 0: + # TODO: switch to zero_point.item() after adding JIT support + zero_point = torch.tensor( + [int(zero_point)], dtype=zero_point.dtype, device=device + ) + if qscheme == torch.per_channel_affine_float_qparams: + zero_point = torch.tensor( + [float(zero_point)], dtype=zero_point.dtype, device=device + ) + + return scale.to(torch.double), zero_point.to(torch.int64) + + +def _get_num_pos_args(f: Callable) -> int: + """Get number of positional args for a function + + Example:: + + >> def f(self, key1=3, key2=3): + pass + >> _get_num_pos_args(f) + 3 + """ + return len(getfullargspec(f).args) + + +def get_fqn_to_example_inputs( + model: torch.nn.Module, example_inputs: tuple[Any, ...] +) -> dict[str, tuple[Any, ...]]: + """Given a model and its example inputs, return a dictionary from + fully qualified name of submodules to example_inputs for that submodule, + e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,), + "sub.linear1": (tensor4,), ...} + + Used to make quantizing submodules easier now that FX Graph Mode Quantization requires + example inputs. + + Also works for keyword arguments with default values, we would flatten keyword + arguments as positional arguments and fill in the missing keyword args with default + values, e.g. if we have a forward function: + def forward(self, x, key1=3, key2=3): + ... + + and we call it with self.submodule(x, key2=6) + we'll get example_inputs: (x, 3, 6) + + user can also override `key1` with positional arguments as well: + for self.submodule(x, 5, key2=6) + we'll get: (x, 5, 6) + + variable positional arguments and variable positional keyword arguments in forward + function are not supported currently, so please make sure no submodules is using + them. 
+ """ + root = model + fqn_to_example_inputs = {} + + def _patched_module_call(self, *args, **kwargs): + submodule_example_inputs = list(args).copy() + normalized_kwargs = _normalize_kwargs(self.forward, kwargs) + # minus 1 to skipping counting `self` + num_args = _get_num_pos_args(self.forward) - 1 + num_to_pop = num_args - len(submodule_example_inputs) + while num_to_pop and normalized_kwargs: + normalized_kwargs.popitem(last=False) + num_to_pop -= 1 + submodule_example_inputs.extend(normalized_kwargs.values()) + submodule_example_inputs_tuple = tuple(submodule_example_inputs) + fqn = _get_path_of_module(root, self) + if fqn is not None: + fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple + return orig_module_call(self, *args, **kwargs) + + orig_module_call = torch.nn.Module.__call__ + torch.nn.Module.__call__ = _patched_module_call # type: ignore[method-assign] + try: + model(*example_inputs) + finally: + # restore the module call even if there is an exception + torch.nn.Module.__call__ = orig_module_call # type: ignore[method-assign] + return fqn_to_example_inputs + + +def _assert_and_get_unique_device(module: torch.nn.Module) -> Any: + """ + Returns the unique device for a module, or None if no device is found. + Throws an error if multiple devices are detected. + """ + devices = {p.device for p in module.parameters()} | { + p.device for p in module.buffers() + } + """ + As a temp workaround for AIMP HHC publish we added CPU check.remove it later. T163614564 + """ + if {torch.device("cpu"), torch.device("meta")} == devices: + warnings.warn( + "Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'." + ) + devices = {torch.device("cpu")} + "" + assert len(devices) <= 1, ( + "prepare only works with cpu or single-device CUDA modules, " + f"but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + return device + + +__all__ = [ + "NodePattern", + "Pattern", + "MatchAllNode", + "check_node", + "get_combined_dict", + "is_per_tensor", + "is_per_channel", + "getattr_from_fqn", + "get_qparam_dict", + "get_swapped_custom_module_class", + "activation_dtype", + "weight_dtype", + "activation_is_statically_quantized", + "activation_is_dynamically_quantized", + "activation_is_int8_quantized", + "activation_is_int32_quantized", + "weight_is_quantized", + "weight_is_statically_quantized", + "op_is_int8_dynamically_quantized", + "get_qconfig_dtypes", + "get_quant_type", + "check_min_max_valid", + "calculate_qmin_qmax", + "has_no_children_ignoring_parametrizations", + "get_fqn_to_example_inputs", + "to_underlying_dtype", + "determine_qparams", + "validate_qmin_qmax", +] diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 05be8c5c30..cee17c2487 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -197,7 +197,7 @@ class TorchAODType(Enum): _ONES_TABLE = [_n_ones(i) for i in range(8)] -quant_lib = torch.library.Library("quant", "FRAGMENT") +quant_lib = torch.library.Library("torchao_quant", "FRAGMENT") register_custom_op = _register_custom_op(quant_lib) diff --git a/torchao/testing/pt2e/utils.py b/torchao/testing/pt2e/utils.py new file mode 100644 index 0000000000..4597cc7b77 --- /dev/null +++ b/torchao/testing/pt2e/utils.py @@ -0,0 +1,158 @@ +import copy + +import torch +from torch.ao.quantization.backend_config import ( + get_executorch_backend_config, +) +from torch.ao.quantization.quantize_fx import ( + 
_convert_to_reference_decomposed_fx, + prepare_fx, +) +from torch.export import export_for_training +from torch.testing._internal.common_quantization import ( + NodeSpec, + QuantizationTestCase, +) + +from torchao.quantization.pt2e_flow.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, +) + + +class PT2EQuantizationTestCase(QuantizationTestCase): + """ + Base QuantizationTestCase for PT2 with some helper methods. + """ + + _MAP_TO_FX_TRACED_OPS = { + torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_channel: torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_channel: torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + } + + def _test_quantizer( + self, + model, + example_inputs, + quantizer, + expected_node_occurrence, + expected_node_list=None, + check_against_fx_quant=False, + # TODO: remove the test if fx quant is removed from pytorch + fx_qconfig_mapping=None, + export_with_dynamic_shape=False, + is_qat=False, + is_debug_mode=False, + training_ir_node_occurrence=None, + ): + # resetting dynamo cache + torch._dynamo.reset() + m_eager = model.eval() + + # program capture + m = copy.deepcopy(m_eager) + dynamic_shapes = tuple( + {0: torch.export.Dim("dim")} if i == 0 else None + for i in range(len(example_inputs)) + ) + m = export_for_training( + m, + example_inputs, + dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + ).module() + + if is_qat: + m = prepare_qat_pt2e(m, quantizer) + else: + m = prepare_pt2e(m, quantizer) + if is_debug_mode: + print("prepared model:", m) + # Calibrate + m(*example_inputs) + m = convert_pt2e(m) + if is_debug_mode: + print("quantized model", m) + + pt2_quant_output = m(*example_inputs) + ns = NodeSpec + node_occurrence = { + ns.call_function(k): v for k, v in expected_node_occurrence.items() + } + if expected_node_list is None: + expected_node_list = [] + node_list = [ns.call_function(n) for n in expected_node_list] + self.checkGraphModuleNodes( + m, expected_node_occurrence=node_occurrence, expected_node_list=node_list + ) + if check_against_fx_quant: + qconfig_mapping = fx_qconfig_mapping + backend_config = get_executorch_backend_config() + m_copy = copy.deepcopy(m_eager) + m_fx = prepare_fx( + m_copy, qconfig_mapping, example_inputs, backend_config=backend_config + ) + m_fx(*example_inputs) + m_fx = _convert_to_reference_decomposed_fx( + m_fx, backend_config=backend_config + ) + m_fx = export_for_training( + m_fx, + example_inputs, + dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + ).module() + node_occurrence = {} + for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): + if k in expected_node_occurrence: + node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] + if training_ir_node_occurrence is not None: + node_occurrence = { + ns.call_function(k): v + for k, v in training_ir_node_occurrence.items() 
+ } + self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) + fx_quant_output = m_fx(*example_inputs) + self.assertEqual(fx_quant_output, pt2_quant_output) + return m + + def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): + # resetting dynamo cache + torch._dynamo.reset() + + m = export_for_training( + m, + example_inputs, + ).module() + if is_qat: + m = prepare_qat_pt2e(m, quantizer) + else: + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m) + return m + + def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule: + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=is_per_channel + ) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + return self._quantize(m, quantizer, example_inputs) From 800404cdb22473dda45ba2670760938a16e27899 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 18 Mar 2025 11:14:02 -0700 Subject: [PATCH 2/6] fix import --- torchao/quantization/pt2e_flow/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/pt2e_flow/__init__.py b/torchao/quantization/pt2e_flow/__init__.py index ffb2972462..edb380da34 100644 --- a/torchao/quantization/pt2e_flow/__init__.py +++ b/torchao/quantization/pt2e_flow/__init__.py @@ -9,12 +9,12 @@ FakeQuantize, FakeQuantizeBase, FixedQParamsFakeQuantize, + FusedMovingAvgObsFakeQuantize, enable_fake_quant, enable_observer, ) from .observer import ( FixedQParamsObserver, - FusedMovingAvgObsFakeQuantize, Granularity, HistogramObserver, MappingType, From d7933a09b489edb89980d5283fd3685c6efbf01b Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 18 Mar 2025 22:32:24 -0700 Subject: [PATCH 3/6] add pytorch version filter --- test/quantization/pt2e_flow/test_duplicate_dq.py | 2 ++ test/quantization/pt2e_flow/test_graph_utils.py | 2 ++ test/quantization/pt2e_flow/test_metadata_porting.py | 2 ++ test/quantization/pt2e_flow/test_numeric_debugger.py | 2 ++ test/quantization/pt2e_flow/test_quantize_pt2e.py | 4 ++++ test/quantization/pt2e_flow/test_quantize_pt2e_qat.py | 5 +++++ test/quantization/pt2e_flow/test_representation.py | 3 +++ test/quantization/pt2e_flow/test_x86inductor_quantizer.py | 3 +++ test/quantization/pt2e_flow/test_xnnpack_quantizer.py | 4 ++++ torchao/quantization/pt2e_flow/__init__.py | 2 +- 10 files changed, 28 insertions(+), 1 deletion(-) diff --git a/test/quantization/pt2e_flow/test_duplicate_dq.py b/test/quantization/pt2e_flow/test_duplicate_dq.py index faec98e589..472a7f81f1 100644 --- a/test/quantization/pt2e_flow/test_duplicate_dq.py +++ b/test/quantization/pt2e_flow/test_duplicate_dq.py @@ -28,6 +28,7 @@ OP_TO_ANNOTATOR, QuantizationConfig, ) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class TestHelperModules: @@ -91,6 +92,7 @@ def forward(self, x): @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestDuplicateDQPass(QuantizationTestCase): def _test_duplicate_dq( self, diff --git a/test/quantization/pt2e_flow/test_graph_utils.py b/test/quantization/pt2e_flow/test_graph_utils.py index 42ac3f244f..b0f6b4f22b 100644 --- a/test/quantization/pt2e_flow/test_graph_utils.py +++ 
b/test/quantization/pt2e_flow/test_graph_utils.py @@ -11,8 +11,10 @@ get_equivalent_types, update_equivalent_types_dict, ) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestGraphUtils(TestCase): @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") def test_conv_bn_conv_relu(self): diff --git a/test/quantization/pt2e_flow/test_metadata_porting.py b/test/quantization/pt2e_flow/test_metadata_porting.py index c2655d114b..f36edb4fbb 100644 --- a/test/quantization/pt2e_flow/test_metadata_porting.py +++ b/test/quantization/pt2e_flow/test_metadata_porting.py @@ -16,6 +16,7 @@ from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer_utils import ( OP_TO_ANNOTATOR, ) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class TestHelperModules: @@ -59,6 +60,7 @@ def _tag_partitions( # TODO: rename to TestPortMetadataPass to align with the util name? @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestMetaDataPorting(QuantizationTestCase): def _test_quant_tag_preservation_through_decomp( self, model, example_inputs, from_node_to_tags diff --git a/test/quantization/pt2e_flow/test_numeric_debugger.py b/test/quantization/pt2e_flow/test_numeric_debugger.py index 3c9a7783a7..bc137e738b 100644 --- a/test/quantization/pt2e_flow/test_numeric_debugger.py +++ b/test/quantization/pt2e_flow/test_numeric_debugger.py @@ -24,9 +24,11 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestNumericDebugger(TorchDynamoTestCase): def _assert_each_node_has_debug_handle(self, model) -> None: def _assert_node_has_debug_handle(node): diff --git a/test/quantization/pt2e_flow/test_quantize_pt2e.py b/test/quantization/pt2e_flow/test_quantize_pt2e.py index 2940586c64..85d155b12c 100644 --- a/test/quantization/pt2e_flow/test_quantize_pt2e.py +++ b/test/quantization/pt2e_flow/test_quantize_pt2e.py @@ -2,6 +2,8 @@ # ruff: noqa: F841 +import unittest + import torch from torch import Tensor from torch.ao.quantization import QConfigMapping @@ -60,9 +62,11 @@ QuantizationConfig, ) from torchao.testing.pt2e.utils import PT2EQuantizationTestCase +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @skipIfNoQNNPACK +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestQuantizePT2E(PT2EQuantizationTestCase): def test_simple_quantizer(self): # TODO: use OP_TO_ANNOTATOR diff --git a/test/quantization/pt2e_flow/test_quantize_pt2e_qat.py b/test/quantization/pt2e_flow/test_quantize_pt2e_qat.py index 2d575dc140..4c86794337 100644 --- a/test/quantization/pt2e_flow/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e_flow/test_quantize_pt2e_qat.py @@ -41,6 +41,7 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class PT2EQATTestCase(QuantizationTestCase): @@ -860,6 +861,7 @@ def test_fold_bn_erases_bn_node(self): @skipIfNoQNNPACK +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base): dim = 1 example_inputs = (torch.randn(1, 3, 5),) @@ -869,6 +871,7 @@ class 
TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base): @skipIfNoQNNPACK +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base): dim = 2 example_inputs = (torch.randn(1, 3, 5, 5),) @@ -1037,6 +1040,7 @@ def validate(self, model: torch.fx.GraphModule): @skipIfNoQNNPACK +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestQuantizePT2EQATModels(PT2EQATTestCase): @skip_if_no_torchvision @skipIfNoQNNPACK @@ -1059,6 +1063,7 @@ def test_qat_mobilenet_v2(self): self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestQuantizeMixQATAndPTQ(QuantizationTestCase): class TwoLinear(torch.nn.Module): def __init__(self) -> None: diff --git a/test/quantization/pt2e_flow/test_representation.py b/test/quantization/pt2e_flow/test_representation.py index 75a8b50906..50ee0cf855 100644 --- a/test/quantization/pt2e_flow/test_representation.py +++ b/test/quantization/pt2e_flow/test_representation.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: quantization"] import copy +import unittest from typing import Any, Optional import torch @@ -20,9 +21,11 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @skipIfNoQNNPACK +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestPT2ERepresentation(QuantizationTestCase): def _test_representation( self, diff --git a/test/quantization/pt2e_flow/test_x86inductor_quantizer.py b/test/quantization/pt2e_flow/test_x86inductor_quantizer.py index 4d8b25ce0b..3c3a4c41ee 100644 --- a/test/quantization/pt2e_flow/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e_flow/test_x86inductor_quantizer.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: quantization"] import copy import itertools +import unittest from enum import Enum import torch @@ -28,6 +29,7 @@ QUANT_ANNOTATION_KEY, X86InductorQuantizer, ) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class NodePosType(Enum): @@ -599,6 +601,7 @@ def _test_quantizer( @skipIfNoInductorSupport +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): @skipIfNoX86 def test_conv2d(self): diff --git a/test/quantization/pt2e_flow/test_xnnpack_quantizer.py b/test/quantization/pt2e_flow/test_xnnpack_quantizer.py index 33b35ffe37..460b3c90ae 100644 --- a/test/quantization/pt2e_flow/test_xnnpack_quantizer.py +++ b/test/quantization/pt2e_flow/test_xnnpack_quantizer.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: mobile"] import copy import operator +import unittest import torch import torch._dynamo as torchdynamo @@ -40,9 +41,11 @@ get_symmetric_quantization_config, ) from torchao.testing.pt2e.utils import PT2EQuantizationTestCase +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @skipIfNoQNNPACK +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestXNNPACKQuantizer(PT2EQuantizationTestCase): def test_conv1d(self): quantizer = XNNPACKQuantizer() @@ -1030,6 +1033,7 @@ def forward(self, x): # TODO: express this using self._test_quantizer, add test for inception_v4 +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase): @skip_if_no_torchvision @skipIfNoQNNPACK diff --git a/torchao/quantization/pt2e_flow/__init__.py 
b/torchao/quantization/pt2e_flow/__init__.py index edb380da34..eb98410ad4 100644 --- a/torchao/quantization/pt2e_flow/__init__.py +++ b/torchao/quantization/pt2e_flow/__init__.py @@ -14,6 +14,7 @@ enable_observer, ) from .observer import ( + AffineQuantizedObserverBase, FixedQParamsObserver, Granularity, HistogramObserver, @@ -38,7 +39,6 @@ ZeroPointDomain, get_block_size, ) -from .pt2e._affine_quantization import AffineQuantizedObserverBase from .pt2e._numeric_debugger import ( # noqa: F401 CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY, From 17f589f393698a2686c8df6245fb20f0270f5d7c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 19 Mar 2025 00:14:10 -0700 Subject: [PATCH 4/6] fix import --- torchao/quantization/pt2e_flow/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/quantization/pt2e_flow/__init__.py b/torchao/quantization/pt2e_flow/__init__.py index eb98410ad4..f93407db3e 100644 --- a/torchao/quantization/pt2e_flow/__init__.py +++ b/torchao/quantization/pt2e_flow/__init__.py @@ -10,6 +10,7 @@ FakeQuantizeBase, FixedQParamsFakeQuantize, FusedMovingAvgObsFakeQuantize, + default_fake_quant, enable_fake_quant, enable_observer, ) @@ -114,6 +115,7 @@ "TorchAODType", "ZeroPointDomain", "get_block_size", + "default_fake_quant", ] From 00f50650c7f67788f6f6d0b85d0ce25eebf7cb4b Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 19 Mar 2025 11:13:23 -0700 Subject: [PATCH 5/6] import --- test/quantization/pt2e_flow/test_duplicate_dq.py | 4 +++- test/quantization/pt2e_flow/test_graph_utils.py | 2 -- test/quantization/pt2e_flow/test_numeric_debugger.py | 4 +++- test/quantization/pt2e_flow/test_quantize_pt2e.py | 4 +++- test/quantization/pt2e_flow/test_quantize_pt2e_qat.py | 4 +++- test/quantization/pt2e_flow/test_representation.py | 4 +++- .../pt2e_flow/test_x86inductor_quantizer.py | 4 +++- test/quantization/pt2e_flow/test_xnnpack_quantizer.py | 5 +++-- torchao/testing/pt2e/utils.py | 10 +++++++++- 9 files changed, 30 insertions(+), 11 deletions(-) diff --git a/test/quantization/pt2e_flow/test_duplicate_dq.py b/test/quantization/pt2e_flow/test_duplicate_dq.py index 472a7f81f1..c462cd76d5 100644 --- a/test/quantization/pt2e_flow/test_duplicate_dq.py +++ b/test/quantization/pt2e_flow/test_duplicate_dq.py @@ -5,7 +5,6 @@ from typing import Any import torch -from torch.export import export_for_training from torch.testing._internal.common_quantization import QuantizationTestCase from torch.testing._internal.common_utils import IS_WINDOWS @@ -30,6 +29,9 @@ ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +if TORCH_VERSION_AT_LEAST_2_5: + from torch.export import export_for_training + class TestHelperModules: class Conv2dWithObsSharingOps(torch.nn.Module): diff --git a/test/quantization/pt2e_flow/test_graph_utils.py b/test/quantization/pt2e_flow/test_graph_utils.py index b0f6b4f22b..42ac3f244f 100644 --- a/test/quantization/pt2e_flow/test_graph_utils.py +++ b/test/quantization/pt2e_flow/test_graph_utils.py @@ -11,10 +11,8 @@ get_equivalent_types, update_equivalent_types_dict, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestGraphUtils(TestCase): @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") def test_conv_bn_conv_relu(self): diff --git a/test/quantization/pt2e_flow/test_numeric_debugger.py b/test/quantization/pt2e_flow/test_numeric_debugger.py index bc137e738b..76e8070971 100644 --- a/test/quantization/pt2e_flow/test_numeric_debugger.py +++ 
b/test/quantization/pt2e_flow/test_numeric_debugger.py @@ -6,7 +6,6 @@ import torch from torch._dynamo.test_case import TestCase as TorchDynamoTestCase -from torch.export import export_for_training from torch.testing._internal.common_quantization import TestHelperModules from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef @@ -26,6 +25,9 @@ ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +if TORCH_VERSION_AT_LEAST_2_5: + from torch.export import export_for_training + @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") diff --git a/test/quantization/pt2e_flow/test_quantize_pt2e.py b/test/quantization/pt2e_flow/test_quantize_pt2e.py index 85d155b12c..06a0a57b21 100644 --- a/test/quantization/pt2e_flow/test_quantize_pt2e.py +++ b/test/quantization/pt2e_flow/test_quantize_pt2e.py @@ -13,7 +13,6 @@ per_channel_weight_observer_range_neg_127_to_127, weight_observer_range_neg_127_to_127, ) -from torch.export import export_for_training from torch.fx import Node from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -64,6 +63,9 @@ from torchao.testing.pt2e.utils import PT2EQuantizationTestCase from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +if TORCH_VERSION_AT_LEAST_2_5: + from torch.export import export_for_training + @skipIfNoQNNPACK @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") diff --git a/test/quantization/pt2e_flow/test_quantize_pt2e_qat.py b/test/quantization/pt2e_flow/test_quantize_pt2e_qat.py index 4c86794337..55b6080d71 100644 --- a/test/quantization/pt2e_flow/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e_flow/test_quantize_pt2e_qat.py @@ -7,7 +7,6 @@ import torch from torch.ao.quantization import QConfigMapping from torch.ao.quantization.quantize_fx import prepare_qat_fx -from torch.export import export_for_training from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -43,6 +42,9 @@ ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +if TORCH_VERSION_AT_LEAST_2_5: + from torch.export import export_for_training + class PT2EQATTestCase(QuantizationTestCase): """ diff --git a/test/quantization/pt2e_flow/test_representation.py b/test/quantization/pt2e_flow/test_representation.py index 50ee0cf855..4af535d61d 100644 --- a/test/quantization/pt2e_flow/test_representation.py +++ b/test/quantization/pt2e_flow/test_representation.py @@ -5,7 +5,6 @@ import torch from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 -from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, ) @@ -23,6 +22,9 @@ ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +if TORCH_VERSION_AT_LEAST_2_5: + from torch.export import export_for_training + @skipIfNoQNNPACK @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") diff --git a/test/quantization/pt2e_flow/test_x86inductor_quantizer.py b/test/quantization/pt2e_flow/test_x86inductor_quantizer.py index 3c3a4c41ee..fd6e9210a1 100644 --- a/test/quantization/pt2e_flow/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e_flow/test_x86inductor_quantizer.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn -from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, ) @@ -31,6 +30,9 @@ ) from torchao.utils import 
TORCH_VERSION_AT_LEAST_2_5 +if TORCH_VERSION_AT_LEAST_2_5: + from torch.export import export_for_training + class NodePosType(Enum): left = 1 diff --git a/test/quantization/pt2e_flow/test_xnnpack_quantizer.py b/test/quantization/pt2e_flow/test_xnnpack_quantizer.py index 460b3c90ae..33ce9e2fc6 100644 --- a/test/quantization/pt2e_flow/test_xnnpack_quantizer.py +++ b/test/quantization/pt2e_flow/test_xnnpack_quantizer.py @@ -24,7 +24,6 @@ convert_to_reference_fx, prepare_fx, ) -from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, ) @@ -43,9 +42,11 @@ from torchao.testing.pt2e.utils import PT2EQuantizationTestCase from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +if TORCH_VERSION_AT_LEAST_2_5: + from torch.export import export_for_training + @skipIfNoQNNPACK -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "only works for torch 2.5+") class TestXNNPACKQuantizer(PT2EQuantizationTestCase): def test_conv1d(self): quantizer = XNNPACKQuantizer() diff --git a/torchao/testing/pt2e/utils.py b/torchao/testing/pt2e/utils.py index 4597cc7b77..107178af61 100644 --- a/torchao/testing/pt2e/utils.py +++ b/torchao/testing/pt2e/utils.py @@ -1,4 +1,5 @@ import copy +import unittest import torch from torch.ao.quantization.backend_config import ( @@ -8,7 +9,6 @@ _convert_to_reference_decomposed_fx, prepare_fx, ) -from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec, QuantizationTestCase, @@ -23,8 +23,16 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +if TORCH_VERSION_AT_LEAST_2_5: + from torch.export import export_for_training + +@unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, + "only works for torch 2.5+ since export_for_training is only supported after 2.5", +) class PT2EQuantizationTestCase(QuantizationTestCase): """ Base QuantizationTestCase for PT2 with some helper methods. 
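
A minimal sketch of the end-to-end flow these test helpers exercise, assuming the torchao.quantization.pt2e_flow import paths used above and torch 2.5+ (where torch.export.export_for_training is available); the toy model and shapes are illustrative:

import torch
from torch.export import export_for_training

from torchao.quantization.pt2e_flow.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e_flow.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# toy model; any module that export_for_training can capture works the same way
m = torch.nn.Sequential(torch.nn.Linear(2, 2)).eval()
example_inputs = (torch.randn(2, 2),)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))

m = export_for_training(m, example_inputs).module()  # program capture
m = prepare_pt2e(m, quantizer)  # insert observers
m(*example_inputs)  # calibrate
m = convert_pt2e(m)  # produce the quantized (q/dq) model
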
From 297c990b2c4e250506c209b5c280eb58fd1bd0e1 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 19 Mar 2025 22:41:47 -0700 Subject: [PATCH 6/6] import --- torchao/quantization/pt2e_flow/__init__.py | 36 +- torchao/quantization/pt2e_flow/qconfig.py | 699 --------------------- 2 files changed, 18 insertions(+), 717 deletions(-) delete mode 100644 torchao/quantization/pt2e_flow/qconfig.py diff --git a/torchao/quantization/pt2e_flow/__init__.py b/torchao/quantization/pt2e_flow/__init__.py index f93407db3e..0a6f5b9706 100644 --- a/torchao/quantization/pt2e_flow/__init__.py +++ b/torchao/quantization/pt2e_flow/__init__.py @@ -5,6 +5,24 @@ import torch from torch import Tensor +from torchao.quantization.pt2e_flow.pt2e._numeric_debugger import ( # noqa: F401 + CUSTOM_KEY, + NUMERIC_DEBUG_HANDLE_KEY, + compare_results, + extract_results_from_loggers, + generate_numeric_debug_handle, + prepare_for_propagation_comparison, +) +from torchao.quantization.pt2e_flow.pt2e.export_utils import ( + _allow_exported_model_train_eval as allow_exported_model_train_eval, +) +from torchao.quantization.pt2e_flow.pt2e.export_utils import ( + _move_exported_model_to_eval as move_exported_model_to_eval, +) +from torchao.quantization.pt2e_flow.pt2e.export_utils import ( + _move_exported_model_to_train as move_exported_model_to_train, +) + from .fake_quantize import ( FakeQuantize, FakeQuantizeBase, @@ -40,24 +58,6 @@ ZeroPointDomain, get_block_size, ) -from .pt2e._numeric_debugger import ( # noqa: F401 - CUSTOM_KEY, - NUMERIC_DEBUG_HANDLE_KEY, - compare_results, - extract_results_from_loggers, - generate_numeric_debug_handle, - prepare_for_propagation_comparison, -) -from .pt2e.export_utils import ( - _allow_exported_model_train_eval as allow_exported_model_train_eval, -) -from .pt2e.export_utils import ( - _move_exported_model_to_eval as move_exported_model_to_eval, -) -from .pt2e.export_utils import ( - _move_exported_model_to_train as move_exported_model_to_train, -) -from .qconfig import * # noqa: F403 # ensure __module__ is set correctly for public APIs ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] diff --git a/torchao/quantization/pt2e_flow/qconfig.py b/torchao/quantization/pt2e_flow/qconfig.py deleted file mode 100644 index c785ff2b1a..0000000000 --- a/torchao/quantization/pt2e_flow/qconfig.py +++ /dev/null @@ -1,699 +0,0 @@ -# mypy: allow-untyped-defs -import copy -import warnings -from collections import namedtuple -from typing import Any, Optional, Union - -import torch -import torch.nn as nn -from torch.ao.quantization.fake_quantize import ( - FakeQuantize, - FakeQuantizeBase, - FusedMovingAvgObsFakeQuantize, - default_dynamic_fake_quant, - default_embedding_fake_quant, - default_embedding_fake_quant_4bit, - default_fake_quant, - default_fused_act_fake_quant, - default_fused_per_channel_wt_fake_quant, - default_fused_wt_fake_quant, - default_per_channel_weight_fake_quant, - default_weight_fake_quant, - fused_per_channel_wt_fake_quant_range_neg_127_to_127, - fused_wt_fake_quant_range_neg_127_to_127, -) -from typing_extensions import deprecated - -from .observer import ( - HistogramObserver, - MinMaxObserver, - MovingAverageMinMaxObserver, - NoopObserver, - ObserverBase, - PlaceholderObserver, - ReuseInputObserver, - _PartialWrapper, - default_debug_observer, - default_dynamic_quant_observer, - default_float_qparams_observer, - default_float_qparams_observer_4bit, - default_observer, - default_per_channel_weight_observer, - default_placeholder_observer, - default_reuse_input_observer, - 
default_weight_observer, - per_channel_weight_observer_range_neg_127_to_127, - weight_observer_range_neg_127_to_127, -) - -__all__ = [ - "QConfig", - # TODO: deprecated, remove - "QConfigDynamic", - "default_qconfig", - "default_debug_qconfig", - "default_per_channel_qconfig", - "default_dynamic_qconfig", - "float16_dynamic_qconfig", - "float16_static_qconfig", - "per_channel_dynamic_qconfig", - "float_qparams_weight_only_qconfig", - "float_qparams_weight_only_qconfig_4bit", - "default_quint8_weight_qconfig", - "default_qat_qconfig", - "default_dynamic_qat_qconfig", - "default_weight_only_qconfig", - "default_activation_only_qconfig", - "default_qat_qconfig_v2", - "default_reuse_input_qconfig", - "default_symmetric_qnnpack_qconfig", - "default_per_channel_symmetric_qnnpack_qconfig", - "default_symmetric_qnnpack_qat_qconfig", - "default_per_channel_symmetric_qnnpack_qat_qconfig", - "default_embedding_qat_qconfig", - "default_embedding_qat_qconfig_4bit", - "get_default_qconfig", - "get_default_qat_qconfig", - "get_default_qconfig_dict", - "get_default_qat_qconfig_dict", - "QConfigAny", - "qconfig_equals", -] - - -class QConfig(namedtuple("QConfig", ["activation", "weight"])): - """ - Describes how to quantize a layer or a part of the network by providing - settings (observer classes) for activations and weights respectively. - - - Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns - instances on invocation, not the concrete observer instances themselves. - Quantization preparation function will instantiate observers multiple times for each of the layers. - - - Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` - method (that behaves like functools.partial):: - - my_qconfig = QConfig( - activation=MinMaxObserver.with_args(dtype=torch.qint8), - weight=default_observer.with_args(dtype=torch.qint8)) - - """ - - __slots__ = () - - def __new__(cls, activation, weight): - # catch common mistakes - if isinstance(activation, nn.Module) or isinstance(weight, nn.Module): - raise ValueError( - "QConfig received observer instance, please pass observer class instead. " - + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" - ) - return super().__new__(cls, activation, weight) - - -@deprecated( - "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead", - category=FutureWarning, -) -class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])): - """ - Describes how to dynamically quantize a layer or a part of the network by providing - settings (observer classes) for weights. - - It's like QConfig, but for dynamic quantization. - - Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns - instances on invocation, not the concrete observer instances themselves. - Quantization function will instantiate observers multiple times for each of the layers. - - Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` - method (that behaves like functools.partial):: - - my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8)) - """ - - __slots__ = () - - def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): - # catch common mistakes - if isinstance(weight, nn.Module): - raise ValueError( - "QConfigDynamic received observer instance, please pass observer class instead. 
" - + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" - ) - return super().__new__(cls, activation, weight) - - -default_qconfig = QConfig(activation=default_observer, weight=default_weight_observer) -""" -Default qconfig configuration. -""" - -default_debug_qconfig = QConfig( - weight=default_weight_observer, activation=default_debug_observer -) -""" -Default qconfig configuration for debugging. -""" - -default_per_channel_qconfig = QConfig( - activation=default_observer, weight=default_per_channel_weight_observer -) -""" -Default qconfig configuration for per channel weight quantization. -""" - -default_dynamic_qconfig = QConfig( - activation=default_dynamic_quant_observer, weight=default_weight_observer -) -""" -Default dynamic qconfig. -""" - -float16_dynamic_qconfig = QConfig( - activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True), - weight=PlaceholderObserver.with_args(dtype=torch.float16), -) -""" -Dynamic qconfig with weights quantized to `torch.float16`. -""" - -float16_static_qconfig = QConfig( - activation=PlaceholderObserver.with_args(dtype=torch.float16), - weight=PlaceholderObserver.with_args(dtype=torch.float16), -) -""" -Dynamic qconfig with both activations and weights quantized to `torch.float16`. -""" - -per_channel_dynamic_qconfig = QConfig( - activation=default_dynamic_quant_observer, - weight=default_per_channel_weight_observer, -) -""" -Dynamic qconfig with weights quantized per channel. -""" - -float_qparams_weight_only_qconfig = QConfig( - activation=default_placeholder_observer, weight=default_float_qparams_observer -) -""" -Dynamic qconfig with weights quantized with a floating point zero_point. -""" - -float_qparams_weight_only_qconfig_4bit = QConfig( - activation=default_placeholder_observer, weight=default_float_qparams_observer_4bit -) - -default_qat_qconfig = QConfig( - activation=default_fake_quant, weight=default_weight_fake_quant -) -""" -Default qconfig for QAT. -""" - -default_dynamic_qat_qconfig = QConfig( - activation=default_dynamic_fake_quant, weight=default_weight_fake_quant -) -""" -Default qconfig for dynamic QAT. -""" - -default_weight_only_qconfig = QConfig( - activation=torch.nn.Identity, weight=default_weight_fake_quant -) -""" -Default qconfig for quantizing weights only. -""" - -default_activation_only_qconfig = QConfig( - activation=default_fake_quant, weight=torch.nn.Identity -) -""" -Default qconfig for quantizing activations only. -""" - -# QAT config that uses a fused observer + fake quant modules for optimized training performance. -# to modify the activation/weight observers, the default entries in fake_quantize.py can be modified. -default_qat_qconfig_v2 = QConfig( - activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant -) -""" -Fused version of `default_qat_config`, has performance benefits. -""" - -default_reuse_input_qconfig = QConfig( - activation=default_reuse_input_observer, weight=NoopObserver -) -""" -Default qconfig for operators that reuse the observers from input Tensor, e.g. reshape -""" - - -def get_default_qconfig(backend="x86", version=0): - """ - Returns the default PTQ qconfig for the specified backend. - - Args: - * `backend` (str): a string representing the target backend. Currently supports - `x86` (default), `fbgemm`, `qnnpack` and `onednn`. 
- - Return: - qconfig - """ - supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] - if backend not in supported_backends: - raise AssertionError( - "backend: " - + str(backend) - + f" not supported. backend must be one of {supported_backends}" - ) - - if version == 0: - if backend == "fbgemm": - qconfig = QConfig( - activation=HistogramObserver.with_args(reduce_range=True), - weight=default_per_channel_weight_observer, - ) - elif backend == "qnnpack": - # TODO: make this compatible with xnnpack constraints - qconfig = QConfig( - activation=HistogramObserver.with_args(reduce_range=False), - weight=default_weight_observer, - ) - elif backend == "onednn": - if not torch.cpu._is_vnni_supported(): - warnings.warn( - "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues " - "on CPU without Vector Neural Network Instruction support." - ) - qconfig = QConfig( - activation=HistogramObserver.with_args(reduce_range=False), - weight=default_per_channel_weight_observer, - ) - elif backend == "x86": - qconfig = QConfig( - activation=HistogramObserver.with_args(reduce_range=True), - weight=default_per_channel_weight_observer, - ) - else: - # won't reach - qconfig = default_qconfig - else: - raise AssertionError( - "Version number: " - + str(version) - + " in get_default_qconfig is not supported. Version number must be 0" - ) - - return qconfig - - -""" -Default, symmetric PTQ qconfig for the specified backend. And a per_channel -variant of the same. - -Symmetric here applies to signed weights with zero point = 0, and additional -value restrictions. The activations are also signed 8-bit integers with this -qconfig. - - * Once this change is merged [as of 3/17/22], with backend or qengine = - 'qnnpack', some quantized operators with this symmetric qconfig may use - operators from xnnpack library. - - ** Support to use xnnpack ops with `qnnpack` backed for asymmetric - qconfig (returned by get_default_qconfig()) is not available yet. - - * This qconfig uses signed activations and weights. Weights have added - restrictions such as zero point is forced to be 0, making the weights - symmetric, hence the name. And the 8-bit quantized values are - restricting to to [-127, +127], excluding -128. - - * xnnpack has a requantization scale value restriction, 0x1p-32 <= - requantization_scale < 256.0 where, `requantization_scale = (input_scale - * kernel_scale) / (output_scale)`. Using this eps (w/ assumed max value - of 256) is to prevent requantization_scale to go below xnnpack lower - threshold. -""" -default_symmetric_qnnpack_qconfig = QConfig( - activation=HistogramObserver.with_args( - dtype=torch.qint8, reduce_range=False, eps=2**-12 - ), - weight=weight_observer_range_neg_127_to_127, -) - -default_per_channel_symmetric_qnnpack_qconfig = QConfig( - activation=HistogramObserver.with_args( - dtype=torch.qint8, reduce_range=False, eps=2**-12 - ), - weight=per_channel_weight_observer_range_neg_127_to_127, -) - -default_embedding_qat_qconfig = QConfig( - activation=NoopObserver.with_args(dtype=torch.float32), - weight=default_embedding_fake_quant, -) - -default_embedding_qat_qconfig_4bit = QConfig( - activation=NoopObserver.with_args(dtype=torch.float32), - weight=default_embedding_fake_quant_4bit, -) - -default_quint8_weight_qconfig = QConfig( - activation=HistogramObserver, weight=MinMaxObserver -) - - -def get_default_qat_qconfig(backend="x86", version=1): - """ - Returns the default QAT qconfig for the specified backend. 
- - Args: - * `backend` (str): a string representing the target backend. Currently supports - `x86` (default), `fbgemm`, `qnnpack` and `onednn`. - * `version`: version, for backwards compatibility. Can be `None` or `1`. - - Return: - qconfig - """ - supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] - if backend not in supported_backends: - raise AssertionError( - "backend: " - + str(backend) - + f" not supported. backend must be one of {supported_backends}" - ) - - # Histogram observer is too slow for quantization aware training - if version == 0: - if backend == "fbgemm": - qconfig = QConfig( - activation=FakeQuantize.with_args( - observer=MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - reduce_range=True, - ), - weight=default_per_channel_weight_fake_quant, - ) - elif backend == "qnnpack": - qconfig = QConfig( - activation=FakeQuantize.with_args( - observer=MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - reduce_range=False, - ), - weight=default_weight_fake_quant, - ) - elif backend == "onednn": - qconfig = QConfig( - activation=FakeQuantize.with_args( - observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255 - ), - weight=default_per_channel_weight_fake_quant, - ) - elif backend == "x86": - qconfig = QConfig( - activation=FakeQuantize.with_args( - observer=MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - reduce_range=True, - ), - weight=default_per_channel_weight_fake_quant, - ) - else: - qconfig = default_qat_qconfig - # Use the fused observe + fake_quant modules for doing QAT. - elif version == 1: - if backend == "fbgemm": - qconfig = QConfig( - activation=FusedMovingAvgObsFakeQuantize.with_args( - observer=MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - reduce_range=True, - ), - weight=default_fused_per_channel_wt_fake_quant, - ) - elif backend == "qnnpack": - # TODO: make this compatible with xnnpack constraints - qconfig = QConfig( - activation=FusedMovingAvgObsFakeQuantize.with_args( - observer=MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - reduce_range=False, - ), - weight=default_fused_wt_fake_quant, - ) - elif backend == "onednn": - qconfig = QConfig( - activation=FusedMovingAvgObsFakeQuantize.with_args( - observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255 - ), - weight=default_fused_per_channel_wt_fake_quant, - ) - elif backend == "x86": - qconfig = QConfig( - activation=FusedMovingAvgObsFakeQuantize.with_args( - observer=MovingAverageMinMaxObserver, - quant_min=0, - quant_max=255, - reduce_range=True, - ), - weight=default_fused_per_channel_wt_fake_quant, - ) - else: - qconfig = default_qat_qconfig_v2 - else: - raise AssertionError( - "Version number: " - + str(version) - + "in get_default_qat_qconfig is not supported. Version number must be 0 or 1" - ) - - return qconfig - - -""" -Default symmetric QAT qconfig for qnnpack. And its per channel weight variant. 
-""" -default_symmetric_qnnpack_qat_qconfig = QConfig( - activation=FusedMovingAvgObsFakeQuantize.with_args( - observer=MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - reduce_range=False, - eps=2**-12, - ), - weight=fused_wt_fake_quant_range_neg_127_to_127, -) - -default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig( - activation=FusedMovingAvgObsFakeQuantize.with_args( - observer=MovingAverageMinMaxObserver, - quant_min=-128, - quant_max=127, - dtype=torch.qint8, - reduce_range=False, - eps=2**-12, - ), - weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127, -) - -_default_fp32_placeholder_qconfig = QConfig( - activation=PlaceholderObserver.with_args(dtype=torch.float32), - weight=PlaceholderObserver.with_args(dtype=torch.float32), -) - -_default_quint8_placeholder_qconfig = QConfig( - activation=PlaceholderObserver.with_args(dtype=torch.quint8), - # operators using this qconfig doesn't have weights - weight=None, -) - - -@deprecated( - "`torch.ao.quantization.get_default_qconfig_dict` is deprecated and will be removed in " - "a future version. Please use `torch.ao.quantization.get_default_qconfig_mapping` instead.", - category=FutureWarning, -) -def get_default_qconfig_dict(backend="x86", version=0): - return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict() - - -@deprecated( - "`torch.ao.quantization.get_default_qat_qconfig_dict` is deprecated and will be removed in " - "a future version. Please use `torch.ao.quantization.get_default_qat_qconfig_mapping` instead.", - category=FutureWarning, -) -def get_default_qat_qconfig_dict(backend="x86", version=1): - return torch.ao.quantization.get_default_qat_qconfig_mapping( - backend, version - ).to_dict() - - -def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> None: - """ - Verifies that this `qconfig` is valid. - """ - if qconfig is None: - return - is_conv_transpose_mod = isinstance( - mod, - (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d), - ) - if is_conv_transpose_mod: - if qconfig.weight is None: - # for now, we assume that any qconfig for ConvTranspose without a weight is valid - return - example_observer = qconfig.weight() - is_per_channel = isinstance( - example_observer, - ( - torch.ao.quantization.PerChannelMinMaxObserver, - torch.ao.quantization.MovingAveragePerChannelMinMaxObserver, - ), - ) - assert ( - not is_per_channel - ), "Per channel weight observer is not supported yet for ConvTranspose{n}d." - - -QConfigAny = Optional[QConfig] -QConfigAny.__module__ = "torch.ao.quantization.qconfig" - - -def _add_module_to_qconfig_obs_ctr( - qconfig: QConfigAny, module: Optional[nn.Module] -) -> Any: - r"""This is a helper function for use in quantization prepare that updates a qconfig so that - the constructors stored in the qconfig will create observers on the same device that - 'module' is on. This is intended to be used when the qconfigs are propagated to each - module in order to avoid potential device alignment issues. 
- - Args: - qconfig: QConfig with obs constructors stored in activation and weight - module: module which the qconfig is related to - - Return: - qconfig: configured so that obs constructors set to construct on the same device as module - """ - - if module is None or qconfig is None or qconfig._fields != ("activation", "weight"): - return qconfig - - def get_factory_kwargs_based_on_module_device(): - assert isinstance(module, torch.nn.Module) - devices = {p.device for p in module.parameters()} | { - p.device for p in module.buffers() - } - device = next(iter(devices)) if len(devices) > 0 else None - return None if device is None else {"device": device} - - def configure_constructor_to_put_obs_on_module_device(original_constructor): - try: - # check if constructor can accept factory_kwargs - check = original_constructor.with_args(factory_kwargs=None) - check() - return original_constructor.with_callable_args( - factory_kwargs=get_factory_kwargs_based_on_module_device - ) - except AttributeError: # qconfig doesn't have activation or weight - return original_constructor - except TypeError: # the class doesn't accept factory_kwargs argument - return original_constructor - - activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation) - weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight) - - return QConfig(activation, weight) - - -_ObserverOrFakeQuantizeConstructor = Union[ - _PartialWrapper, type[ObserverBase], type[FakeQuantizeBase] -] - - -def _obs_or_fq_ctr_equals( - obs_or_fq1: _ObserverOrFakeQuantizeConstructor, - obs_or_fq2: _ObserverOrFakeQuantizeConstructor, -): - if isinstance(obs_or_fq1, _PartialWrapper) and isinstance( - obs_or_fq2, _PartialWrapper - ): - return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2) - return obs_or_fq1 == obs_or_fq2 - - -def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper): - """ - Return whether the two partial wrappers are equal, - """ - # functools.partial has no __eq__ operator defined so '==' defaults to 'is' - obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords) - obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords) - keywords_equal = True - # compare observer constructor with _obs_or_fq_ctr_equals since direct compare would fail - if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords: - keywords_equal = keywords_equal and _obs_or_fq_ctr_equals( - obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"] - ) - obs_or_fq1_keywords.pop("observer") - obs_or_fq2_keywords.pop("observer") - keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords - return ( - obs_or_fq1.p.func == obs_or_fq2.p.func - and obs_or_fq1.p.args == obs_or_fq2.p.args - and keywords_equal - ) - - -def qconfig_equals(q1: QConfigAny, q2: QConfigAny): - """ - Returns `True` if `q1` equals `q2`, and `False` otherwise. - """ - if q1 is None or q2 is None: - return q1 == q2 - else: - assert q1 is not None and q2 is not None - try: - # Qconfig weight and activation can be either a partial wrapper, - # or an observer class. Special handling is required (above) for - # comparing partial wrappers. - activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation) - weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight) - return activation_same and weight_same - except AttributeError: - return q1 == q2 - - -def _activation_is_memoryless(qconfig: QConfig): - """ - Return whether the observer for activations defined in the given QConfig is memoryless. 
- This means a MovingAverage observer with averaging constant equal to 1. - """ - - def _is_memoryless(observer): - return ( - hasattr(observer, "averaging_constant") and observer.averaging_constant == 1 - ) - - act = qconfig.activation() - if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"): - return _is_memoryless(act.activation_post_process) - else: - return _is_memoryless(act) - - -def _is_reuse_input_qconfig(qconfig: Optional[QConfig]): - return ( - qconfig is not None - and isinstance(qconfig.activation(), ReuseInputObserver) - and isinstance(qconfig.weight(), NoopObserver) - )
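
After this patch, the pt2e public APIs are re-exported from torchao.quantization.pt2e_flow through absolute imports, and the duplicated eager-mode qconfig.py is removed from the package. The sketch below shows what a downstream call site could look like once the patch lands; it is an illustration only (not part of the patch), assuming the package path shown in the diff is kept, and the fallback QConfig import points at torch.ao.quantization, which the deleted file mirrored.

    # Hypothetical call site after PATCH 6/6 (illustration, not part of the patch).
    # pt2e helpers are now re-exported from the package __init__ via absolute imports.
    from torchao.quantization.pt2e_flow import (
        FakeQuantize,                      # still re-exported from .fake_quantize
        allow_exported_model_train_eval,   # export_utils helper, re-exported under its public name
        generate_numeric_debug_handle,     # numeric-debugger entry point
        move_exported_model_to_eval,
        move_exported_model_to_train,
    )

    # `from .qconfig import *` is gone and qconfig.py is deleted, so eager-mode
    # QConfig users would import from the original torch.ao.quantization location
    # that the removed file duplicated (assumption: that upstream module stays put).
    from torch.ao.quantization.qconfig import QConfig, default_qconfig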