
Commit f1f82cb

jerryzh168 authored and facebook-github-bot committed
[quant][pt2e] Rename _pt2e to pt2e (pytorch#104668)
Summary:
Pull Request resolved: pytorch#104668
X-link: pytorch/executorch#3

att

Test Plan: Imported from OSS
Reviewed By: andrewor14
Differential Revision: D47202807
fbshipit-source-id: 75c10c2443bfee2aa4061632d63edeac6c48421a
1 parent eb03af4 commit f1f82cb

24 files changed: +98 -65 lines
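
Downstream code only needs the module path updated; a minimal before/after sketch (illustrative, using one of the imports this commit touches):

    # Before this commit (private path):
    # from torch.ao.quantization._pt2e.quantizer import X86InductorQuantizer
    # After this commit (renamed path):
    from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer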

docs/source/quantization-support.rst

Lines changed: 7 additions & 0 deletions

@@ -120,6 +120,13 @@ This module contains a few CustomConfig classes that's used in both eager mode a
     ConvertCustomConfig
     StandaloneModuleConfigEntry
 
+torch.ao.quantization.pt2e (quantization in pytorch 2.0 export)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: torch.ao.quantization.pt2e
+.. automodule:: torch.ao.quantization.pt2e.quantizer
+.. automodule:: torch.ao.quantization.pt2e.representation
+
 torch (quantization related functions)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

test/inductor/test_inductor_freezing.py

Lines changed: 2 additions & 2 deletions

@@ -12,13 +12,13 @@
 import torch
 
 import torch._dynamo as torchdynamo
-import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
 from torch import nn
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx
 from torch._inductor.utils import override_lowering, run_and_get_code
-from torch.ao.quantization._pt2e.quantizer import X86InductorQuantizer
 from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
+from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer
 from torch.testing import FileCheck
 from torch.testing._internal.common_quantization import (
     skipIfNoDynamoSupport,

test/quantization/pt2e/test_graph_utils.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 import torch
 import torch._dynamo as torchdynamo
 
-from torch.ao.quantization._pt2e.graph_utils import (
+from torch.ao.quantization.pt2e.graph_utils import (
     find_sequential_partitions,
     get_equivalent_types,
     update_equivalent_types_dict,

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 5 additions & 5 deletions

@@ -15,7 +15,7 @@
     ObserverOrFakeQuantize,
     QConfigMapping,
 )
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
     EmbeddingQuantizer,
@@ -27,10 +27,10 @@
     Quantizer,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization._pt2e.quantizer.composable_quantizer import ( # noqa: F811
+from torch.ao.quantization.pt2e.quantizer.composable_quantizer import ( # noqa: F811
     ComposableQuantizer,
 )
-from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (
+from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
 from torch.ao.quantization._quantize_pt2e import (
@@ -1774,7 +1774,7 @@ def __init__(self):
             def forward(self, x, y):
                 return x + y
 
-        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
+        import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq
 
         quantizer = QNNPackQuantizer()
         operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
@@ -1799,7 +1799,7 @@ def __init__(self):
             def forward(self, x, y):
                 return x + y
 
-        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
+        import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq
 
         quantizer = QNNPackQuantizer()
         operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
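
For context, the renamed imports above feed the pt2e flow these tests exercise. A hedged sketch of that flow, assuming the APIs as of this commit (the toy model and inputs are hypothetical):

    import torch
    import torch._dynamo as torchdynamo
    from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
    from torch.ao.quantization.pt2e.quantizer import QNNPackQuantizer
    import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

    m = torch.nn.Linear(4, 4).eval()            # hypothetical toy model
    example_inputs = (torch.randn(1, 4),)

    # Capture an ATen-level graph, as the tests do.
    m, _ = torchdynamo.export(m, *example_inputs, aten_graph=True)

    quantizer = QNNPackQuantizer()
    quantizer.set_global(qq.get_symmetric_quantization_config(is_per_channel=True))

    m = prepare_pt2e_quantizer(m, quantizer)    # insert observers per annotations
    m(*example_inputs)                          # calibrate
    m = convert_pt2e(m)                         # convert to reference quantized ops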

test/quantization/pt2e/test_x86inductor_quantizer.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 import torch
 import torch._dynamo as torchdynamo
 import torch.nn as nn
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     X86InductorQuantizer,
 )
 from torch.ao.quantization._quantize_pt2e import (
@@ -19,7 +19,7 @@
 from torch.testing._internal.common_quantized import override_quantized_engine
 from enum import Enum
 import itertools
-import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
 from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle

torch/_dynamo/skipfiles.py

Lines changed: 3 additions & 3 deletions

@@ -139,9 +139,9 @@ def _module_dir(m: types.ModuleType):
 # TODO: find a better way to express this path without having to import
 # `torch.ao.quantization._pt2e`, which interferes with memory profiling
 FILENAME_ALLOWLIST |= {
-    _module_dir(torch) + "ao/quantization/_pt2e/qat_utils.py",
-    _module_dir(torch) + "ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py",
-    _module_dir(torch) + "ao/quantization/_pt2e/representation/rewrite.py",
+    _module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
+    _module_dir(torch) + "ao/quantization/pt2e/quantizer/qnnpack_quantizer.py",
+    _module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
 }
 
 # TODO (zhxchen17) Make exportdb importable here.

torch/_inductor/freezing.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 from torch._inductor.compile_fx import fake_tensor_prop
 from torch._inductor.fx_passes.freezing_patterns import freezing_passes
 from torch._inductor.fx_passes.post_grad import view_to_reshape
-from torch.ao.quantization._pt2e.utils import _fuse_conv_bn_
+from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.fx.experimental.proxy_tensor import make_fx
 from . import config
 from .decomposition import select_decomp_table

torch/ao/quantization/_quantize_pt2e.py

Lines changed: 6 additions & 6 deletions

@@ -1,21 +1,21 @@
 from torch.fx import GraphModule
 
-from ._pt2e.prepare import prepare
-from ._pt2e._propagate_annotation import propagate_annotation
-from ._pt2e.qat_utils import (
+from .pt2e.prepare import prepare
+from .pt2e._propagate_annotation import propagate_annotation
+from .pt2e.qat_utils import (
     _fuse_conv_bn_qat,
     _fold_conv_bn_qat,
 )
-from ._pt2e.utils import (
+from .pt2e.utils import (
     _get_node_name_to_scope,
     _fuse_conv_bn_,
     _rearrange_weight_observer_for_decomposed_linear,
 )
-from ._pt2e.representation import reference_representation_rewrite
+from .pt2e.representation import reference_representation_rewrite
 from .fx.prepare import prepare as fx_prepare
 from .quantize_fx import _convert_to_reference_decomposed_fx
 from torch.ao.quantization import QConfigMapping
-from torch.ao.quantization._pt2e.quantizer import Quantizer
+from torch.ao.quantization.pt2e.quantizer import Quantizer
 from torch.ao.quantization.backend_config import BackendConfig
 
 from typing import Any, Tuple

torch/ao/quantization/fx/prepare.py

Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@
     PrepareCustomConfig,
     StandaloneModuleConfigEntry,
 )
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     EdgeOrNode,
     QuantizationSpec,
     FixedQParamsQuantizationSpec,

torch/ao/quantization/_pt2e/_propagate_annotation.py renamed to torch/ao/quantization/pt2e/_propagate_annotation.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 from typing import Callable
 
 import torch
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
 )

torch/ao/quantization/_pt2e/graph_utils.py renamed to torch/ao/quantization/pt2e/graph_utils.py

Lines changed: 6 additions & 0 deletions

@@ -10,6 +10,12 @@
     SourcePartition,
 )
 
+__all__ = [
+    "find_sequential_partitions",
+    "get_equivalent_types",
+    "update_equivalent_types_dict",
+]
+
 _EQUIVALENT_TYPES: List[Set] = [
     {torch.nn.Conv2d, torch.nn.functional.conv2d},
     {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
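
The new `__all__` makes the partition helpers this module's public surface. A hedged usage sketch of `find_sequential_partitions` (the toy module is illustrative; the equivalent-types table lets e.g. nn.Conv2d and F.conv2d fill the same slot):

    import torch
    import torch._dynamo as torchdynamo
    from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv(x))

    # Capture an ATen-level graph, then look for Conv2d -> ReLU sequences.
    gm, _ = torchdynamo.export(M().eval(), torch.randn(1, 3, 8, 8), aten_graph=True)
    fused_partitions = find_sequential_partitions(gm, [torch.nn.Conv2d, torch.nn.ReLU])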

torch/ao/quantization/_pt2e/prepare.py renamed to torch/ao/quantization/pt2e/prepare.py

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
 from typing import Dict, Tuple, Union, Any
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     EdgeOrNode,
 )

torch/ao/quantization/_pt2e/qat_utils.py renamed to torch/ao/quantization/pt2e/qat_utils.py

Lines changed: 8 additions & 8 deletions

@@ -14,8 +14,8 @@
     SharedQuantizationSpec,
     QuantizationSpecBase,
 )
-from .utils import _fold_bn_weights_into_conv_node
-from .utils import _get_aten_graph_module
+from .utils import fold_bn_weights_into_conv_node
+from .utils import get_aten_graph_module
 
 # Example inputs for `_conv2d_bn_pattern`, `_qat_conv2d_bn_pattern`, and `_qat_conv2d_bn_pattern_no_bias`
 _conv2d_bn_pattern_example_inputs = (
@@ -494,15 +494,15 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     m.graph.eliminate_dead_code()
     m.recompile()
     example_inputs = _conv2d_bn_pattern_example_inputs
-    match_pattern = _get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
+    match_pattern = get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
 
     # Step (1): Replace patterns with conv bias
     #
     # Here we do replacement separately for cases with and without conv bias, since
     # the replacement patterns for these two cases are substantially different.
     # TODO: use the public replace_pattern API once it also returns replacement nodes
 
-    replacement_pattern_with_conv_bias = _get_aten_graph_module(
+    replacement_pattern_with_conv_bias = get_aten_graph_module(
         _qat_conv2d_bn_pattern,
         example_inputs,
     )
@@ -517,7 +517,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
 
     # Step (2): Replace patterns without conv bias
 
-    replacement_pattern_no_conv_bias = _get_aten_graph_module(
+    replacement_pattern_no_conv_bias = get_aten_graph_module(
         _qat_conv2d_bn_pattern_no_conv_bias,
         example_inputs,
     )
@@ -650,11 +650,11 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
     match_pattern = _get_quantized_qat_conv2d_bn_pattern(
         is_per_channel, has_relu, has_bias, relu_is_inplace,
     )
-    match_pattern = _get_aten_graph_module(match_pattern, example_inputs, **kwargs)
+    match_pattern = get_aten_graph_module(match_pattern, example_inputs, **kwargs)
     replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(
         is_per_channel, has_relu, has_bias, relu_is_inplace,
     )
-    replacement_pattern = _get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
+    replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
     replacements.extend(
         replace_pattern_with_filters(
             m,
@@ -718,7 +718,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
     )
 
     # fold bn weights into conv
-    _fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
+    fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
 
     # Copy over literal args for conv
     for _, original_node in _filter_nodes_map(r.nodes_map).items():
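
The underscore-stripped helpers above are now module-level names in `pt2e.utils`. A hedged sketch of how `get_aten_graph_module` is used for pattern matching, following the call shape in the diff (the pattern function is hypothetical):

    import torch
    from torch.ao.quantization.pt2e.utils import get_aten_graph_module

    def add_pattern(x, y):               # hypothetical pattern to match
        return x + y

    example_inputs = (torch.randn(1), torch.randn(1))
    # Trace the pattern into an ATen-level GraphModule, the same way
    # _fuse_conv_bn_qat traces its conv+bn patterns above.
    pattern_gm = get_aten_graph_module(add_pattern, example_inputs)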

torch/ao/quantization/_pt2e/quantizer/embedding_quantizer.py renamed to torch/ao/quantization/pt2e/quantizer/embedding_quantizer.py

Lines changed: 5 additions & 1 deletion

@@ -5,7 +5,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch.ao.quantization._pt2e.quantizer.quantizer import (
+from torch.ao.quantization.pt2e.quantizer.quantizer import (
     OperatorConfig,
     OperatorPatternType,
     QuantizationAnnotation,
@@ -15,6 +15,10 @@
 )
 from torch.ao.quantization.observer import PerChannelMinMaxObserver
 
+__all__ = [
+    "get_embedding_operators_config",
+    "EmbeddingQuantizer",
+]
 
 def get_embedding_operators_config() -> OperatorConfig:
     weight_quantization_spec = QuantizationSpec(
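
Both exported names are exercised by test_quantize_pt2e.py; a hedged composition sketch (mirroring the test imports above, not a verbatim excerpt):

    from torch.ao.quantization.pt2e.quantizer import (
        ComposableQuantizer,
        EmbeddingQuantizer,
        QNNPackQuantizer,
    )

    # Combine the embedding-specific quantizer with a backend quantizer;
    # ComposableQuantizer applies each child's annotations in order.
    quantizer = ComposableQuantizer([EmbeddingQuantizer(), QNNPackQuantizer()])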

torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py renamed to torch/ao/quantization/pt2e/quantizer/qnnpack_quantizer.py

Lines changed: 8 additions & 8 deletions

@@ -11,9 +11,9 @@
 import torch._dynamo as torchdynamo
 import torch.nn.functional as F
 
-from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 
-from torch.ao.quantization._pt2e.quantizer.utils import (
+from torch.ao.quantization.pt2e.quantizer.utils import (
     _annotate_input_qspec_map,
     _annotate_output_qspec,
     _is_sym_size_node,
@@ -84,7 +84,7 @@ def linear_op(act, weight, bias=None):
     return [pattern_w_bias, pattern_wo_bias]
 
 
-def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
+def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     supported_operators: Dict[str, List[OperatorPatternType]] = {
         # Both conv and linear should be able to handle relu + hardtanh fusion since
         # those are clamp ops
@@ -107,15 +107,15 @@ def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternT
     return copy.deepcopy(supported_operators)
 
 
-def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
+def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
     supported_config_and_operators: List[OperatorConfig] = []
     for quantization_config in [
         get_symmetric_quantization_config(),
         get_symmetric_quantization_config(is_qat=True),
         get_symmetric_quantization_config(is_per_channel=True),
         get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
     ]:
-        ops = supported_symmetric_quantized_operators()
+        ops = _supported_symmetric_quantized_operators()
         for op_string, pattern_list in ops.items():
             supported_config_and_operators.append(
                 OperatorConfig(quantization_config, pattern_list)
@@ -205,8 +205,8 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
-def get_supported_config_and_operators() -> List[OperatorConfig]:
-    return get_supported_symmetric_config_and_operators()
+def _get_supported_config_and_operators() -> List[OperatorConfig]:
+    return _get_supported_symmetric_config_and_operators()
 
 
 def _is_annotated(nodes: List[Node]):
@@ -225,7 +225,7 @@ def _is_annotated(nodes: List[Node]):
 
 
 class QNNPackQuantizer(Quantizer):
-    supported_config_and_operators = get_supported_config_and_operators()
+    supported_config_and_operators = _get_supported_config_and_operators()
 
     def __init__(self):
         super().__init__()

torch/ao/quantization/_pt2e/quantizer/quantizer.py renamed to torch/ao/quantization/pt2e/quantizer/quantizer.py

Lines changed: 10 additions & 5 deletions

@@ -13,9 +13,12 @@
     "QuantizationSpecBase",
     "QuantizationSpec",
     "FixedQParamsQuantizationSpec",
+    "EdgeOrNode",
     "SharedQuantizationSpec",
     "DerivedQuantizationSpec",
     "QuantizationAnnotation",
+    "QuantizationConfig",
+    "OperatorConfig",
 ]
 
 # TODO: maybe remove torch.float32
@@ -84,17 +87,19 @@ class FixedQParamsQuantizationSpec(QuantizationSpecBase):
     quant_max: Optional[int] = None
     qscheme: Optional[torch.qscheme] = None
 
+"""
+The way we refer to other points of quantization in the graph will be either
+an input edge or an output value
+input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
+output value is an fx Node
+"""
 EdgeOrNode = Union[Tuple[Node, Node], Node]
+EdgeOrNode.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"
 
 @dataclass(eq=True, frozen=True)
 class SharedQuantizationSpec(QuantizationSpecBase):
     """
     Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
-
-    The way we refer to other points of quantization in the graph will be either
-    an input edge or an output value
-    input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
-    output value is an fx Node
     """
     edge_or_node: EdgeOrNode
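
The relocated docstring spells out the two forms of `EdgeOrNode`; a hedged, self-contained illustration (the traced function is hypothetical):

    import torch
    from torch.fx import symbolic_trace
    from torch.ao.quantization.pt2e.quantizer import SharedQuantizationSpec

    def f(x, y):
        return x + y

    gm = symbolic_trace(f)
    nodes = list(gm.graph.nodes)             # [x, y, add, output]
    x_node, add_node = nodes[0], nodes[2]

    # "input edge" form: a (producer, consumer) Tuple[Node, Node]
    edge_spec = SharedQuantizationSpec((x_node, add_node))
    # "output value" form: the fx.Node itself
    node_spec = SharedQuantizationSpec(add_node)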

torch/ao/quantization/_pt2e/quantizer/utils.py renamed to torch/ao/quantization/pt2e/quantizer/utils.py

Lines changed: 7 additions & 1 deletion

@@ -1,13 +1,19 @@
 from typing import List
 
 import torch
-from torch.ao.quantization._pt2e.quantizer.quantizer import (
+from torch.ao.quantization.pt2e.quantizer.quantizer import (
     QuantizationAnnotation,
     QuantizationConfig,
     QuantizationSpec,
 )
 from torch.fx import Node
 
+__all__ = [
+    "get_input_act_qspec",
+    "get_output_act_qspec",
+    "get_weight_qspec",
+    "get_bias_qspec",
+]
 
 def get_input_act_qspec(quantization_config: QuantizationConfig):
     if quantization_config is None:
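
The four getters pull individual specs out of a `QuantizationConfig`; a hedged usage sketch pairing them with the QNNPACK config factory seen earlier in this commit:

    from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.pt2e.quantizer.utils import (
        get_bias_qspec,
        get_input_act_qspec,
        get_output_act_qspec,
        get_weight_qspec,
    )

    qc = get_symmetric_quantization_config(is_per_channel=True)
    input_qspec = get_input_act_qspec(qc)     # activation spec for inputs
    output_qspec = get_output_act_qspec(qc)   # activation spec for outputs
    weight_qspec = get_weight_qspec(qc)       # per-channel weight spec here
    bias_qspec = get_bias_qspec(qc)           # bias spec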
