
Commit cfd1be3

Support dynamically quantized 2D convolutions (#10347)
### Summary

Add initial support for dynamically quantized Conv2d in XNNPACK:

- Add `conv` to `DYNAMIC_OPS` for annotation
- Update the partitioner to support dynamically quantized Conv2d
- Add checks to ensure only 2D, non-depthwise dynamically quantized convs are partitioned and annotated
- Update NHWC permute node insertion to trace back to the original input for dynamically quantized inputs
- Compute `num_nonbatch_dims` based on whether the node feeds into a conv
- Remove the `num_nonbatch_dims` check from XNNCompiler
- Add unit tests for channels-last permute and single, sequential, and parallel dynamically quantized 2D convs

Fixes #9021

### Test plan

```bash
python -m unittest backends.xnnpack.test.ops.test_conv2d.TestConv2d.test_dq_conv2d
python -m unittest backends.xnnpack.test.ops.test_conv2d.TestConv2d.test_dq_conv2d_seq
python -m unittest backends.xnnpack.test.ops.test_conv2d.TestConv2d.test_dq_conv2d_parallel
python -m unittest backends.xnnpack.test.passes.test_channels_last_tagged_reshape.TestChannelsLastTaggedReshapePass.test_dq_conv2d_channels_last_tagged_reshape_pass
```
1 parent b52ad91 · commit cfd1be3
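For context, here is roughly how the new path is exercised outside the test harness, mirroring the `_test_dq` helper added in `test_conv2d.py`. This is a sketch, not part of the commit: the module, input shapes, and the `export_for_training`/`prepare_pt2e` entry points are illustrative and vary across ExecuTorch and PyTorch versions.

```python
import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from executorch.exir import to_edge_transform_and_lower
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
inputs = (torch.randn(1, 3, 8, 8),)

# Per-channel weights + dynamic activation quantization, now legal for 2D convs
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

# PT2E quantization: choose_qparams/quantize/dequantize get inserted for the conv input
exported = torch.export.export_for_training(model, inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*inputs)  # calibration is a no-op for dynamic quant, but keeps the flow uniform
converted = convert_pt2e(prepared)

# Lower only the dynamically quantized ops to XNNPACK
program = to_edge_transform_and_lower(
    torch.export.export(converted, inputs),
    partitioner=[
        XnnpackPartitioner(
            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
            per_op_mode=True,
        )
    ],
).to_executorch()
```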

9 files changed, +235, -14 lines


backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py (+14, -1)

```diff
@@ -8,6 +8,7 @@
 
 import torch
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dynamic_qdq
 from executorch.backends.xnnpack.utils.utils import is_param_node
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
@@ -283,14 +284,26 @@ def input_to_nhwc(
             ]
         else:
             # Need to create NHWC node
+            # Check if input uses dynamic quantization
+            is_dynamic_input = is_dynamic_qdq(input_node)
+
+            if is_dynamic_input:
+                # Trace back to original source node
+                while getattr(input_node, "args", None):
+                    input_node = input_node.args[0]
+
             with graph_module.graph.inserting_after(input_node):
                 input_node_nhwc = self.create_call_function_node(
                     graph_module=graph_module,
                     target=exir_ops.edge.aten._to_copy.default,
                     args=(input_node,),
                     memory_format=torch.channels_last,
                 )
-                self.mark_as_nhwc_node(input_node_nhwc)
+            self.mark_as_nhwc_node(input_node_nhwc)
+            if is_dynamic_input:
+                # Replace downstream input_nodes with NHWC node
+                input_node.replace_all_uses_with(input_node_nhwc)
+                input_node_nhwc.args = (input_node,)
 
             self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
                 graph_module=graph_module,
```
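The trace-back loop works because a dynamically quantized input reaches the conv through a single-input chain (`choose_qparams` → `quantize` → `dequantize`): repeatedly following `args[0]` lands on the original fp32 source, a placeholder with no args, so the channels-last `_to_copy` is inserted ahead of the whole chain and the chain is then rewired on top of it via `replace_all_uses_with`. A toy version of the same walk on a symbolically traced graph (illustrative module, not the pass itself):

```python
import torch
import torch.fx

class Chain(torch.nn.Module):
    def forward(self, x):
        # stand-in for a single-input q/dq chain
        return torch.relu(torch.sigmoid(x))

gm = torch.fx.symbolic_trace(Chain())
node = next(n for n in gm.graph.nodes if n.target is torch.relu)

# Follow args[0] until a node with no args is reached (the placeholder input)
while getattr(node, "args", None):
    node = node.args[0]

print(node.op, node.name)  # placeholder x
```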

backends/xnnpack/operators/quant_params.py (+16, -1)

```diff
@@ -141,12 +141,27 @@ def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
             tensor, self.scale, self.zp, self.qmin, self.qmax, self.dtype
         )
 
+    # Temporary helper until non-batch dimensions can be inferred
+    # Detects if a node feeds into a conv op by checking all downstream users
+    @staticmethod
+    def _feeds_into_conv(node: torch.fx.Node) -> bool:
+        users_list = [node]
+
+        while users_list:
+            current_user = users_list.pop()
+            if "convolution" in str(current_user.target):
+                return True
+            users_list.extend(current_user.users)
+
+        return False
+
     @classmethod
     def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
         q_input = quant_node.args[0]  # fp32 input
         assert isinstance(q_input, torch.fx.Node)
         # TODO - materialize this from the quant_node scale count and val shape
-        num_nonbatch_dims = 1
+        # Set non-batch dims to 3 if node feeds into conv (only 2D is supported), otherwise set to 1 for linear
+        num_nonbatch_dims = 3 if cls._feeds_into_conv(quant_node) else 1
 
         return cls(
             per_channel=False,  # True is not valid
```
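`num_nonbatch_dims` tells XNNPACK how many trailing dimensions share one dynamic scale: for a linear input `(..., C)` everything but the last dim is "batch" (per-token scales), while a conv2d input `(N, C, H, W)` needs all of `C, H, W` treated as non-batch so one scale is computed per image in the batch. A rough standalone version of the user-graph walk, adapted to a symbolically traced module where the conv shows up as a `call_module` named `conv` rather than an edge-dialect `convolution` target:

```python
import torch
import torch.fx

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(torch.relu(x))

gm = torch.fx.symbolic_trace(Model())

def feeds_into_conv(node: torch.fx.Node) -> bool:
    # Stack-based walk over downstream users, mirroring _feeds_into_conv
    worklist = [node]
    while worklist:
        n = worklist.pop()
        if "conv" in str(n.target):  # the real helper matches "convolution" in edge IR
            return True
        worklist.extend(n.users)
    return False

relu = next(n for n in gm.graph.nodes if n.target is torch.relu)
print(feeds_into_conv(relu))  # True: per-batch scales (num_nonbatch_dims = 3)
```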

backends/xnnpack/partition/config/gemm_configs.py (+17, -3)

```diff
@@ -9,6 +9,7 @@
 from typing import cast, List, Optional, Tuple
 
 import torch
+from executorch.backends.transforms import get_shape
 from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
@@ -27,6 +28,7 @@
 )
 from executorch.backends.xnnpack.utils.utils import (
     get_input_node,
+    is_depthwise_conv,
     is_getitem,
     is_node,
     is_param_node,
@@ -359,12 +361,23 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False  # Only support 1D + 2D Conv
 
         kernel_node = get_input_node(node, 1)
+        kernel_shape = get_shape(kernel_node)
         weight_quant_params = QuantParams.from_weights(kernel_node, ep)
-
-        is_transpose = node.args[6]
         groups = cast(int, node.args[8])
+        is_transpose = node.args[6]
+
+        # XNNPACK does not support dynamic quantization convs that are not 2D or are depthwise
+        if self._detect_precision(node) == ConfigPrecisionType.DYNAMIC_QUANT and (
+            len(conv_stride) != 2
+            or is_depthwise_conv(kernel_shape, groups, is_transpose)
+        ):
+            why(
+                node,
+                "XNNPACK only supports standard 2D convolutions for dynamic quantization",
+            )
+            return False
 
-        # XNNPack does not support non-zero output padding in transposed
+        # XNNPACK does not support non-zero output padding in transposed
         # convolutions.
         if is_transpose and any(
             out_pad != 0 for out_pad in cast(List[int], node.args[7])
@@ -394,6 +407,7 @@ def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
            ConfigPrecisionType.STATIC_QUANT,
+            ConfigPrecisionType.DYNAMIC_QUANT,
         ]
```
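The new guard leans on `is_depthwise_conv`, which in essence asks whether each group sees a single input channel: with a regular conv weight laid out as `(out_channels, in_channels / groups, kH, kW)`, depthwise means `kernel_shape[1] == 1` with `groups > 1`, and transposed convs carry the group dimension on axis 0 instead. A hedged sketch of that predicate (illustrative, not the exact utility in `backends/xnnpack/utils/utils.py`):

```python
from typing import Sequence

def is_depthwise(kernel_shape: Sequence[int], groups: int, is_transpose: bool) -> bool:
    # Regular conv weight: (out_c, in_c / groups, kH, kW)
    # Transposed conv weight: (in_c, out_c / groups, kH, kW)
    group_input_channels = kernel_shape[0] // groups if is_transpose else kernel_shape[1]
    return groups > 1 and group_input_channels == 1

# nn.Conv2d(8, 8, 3, groups=8).weight has shape (8, 1, 3, 3): depthwise
print(is_depthwise((8, 1, 3, 3), groups=8, is_transpose=False))  # True
# nn.Conv2d(3, 8, 3).weight has shape (8, 3, 3, 3): standard conv
print(is_depthwise((8, 3, 3, 3), groups=1, is_transpose=False))  # False
```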

backends/xnnpack/quantizer/xnnpack_quantizer.py (+1)

```diff
@@ -265,6 +265,7 @@ class XNNPACKQuantizer(Quantizer):
 
     DYNAMIC_OPS = [
         "linear",
+        "conv",
     ]
 
     def __init__(self) -> None:
```

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py (+18, -1)

```diff
@@ -6,6 +6,7 @@
 
 import torch
 import torch.nn.functional as F
+from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
 from torch._subclasses import FakeTensor
 from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
 from torch.ao.quantization.pt2e.export_utils import _WrapperModule
@@ -29,7 +30,6 @@
 )
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 
-
 __all__ = [
     "OperatorConfig",
     "OperatorPatternType",
@@ -323,6 +323,23 @@ def _do_annotate_conv(
         assert isinstance(weight, Node)
         input_qspec_map[weight] = get_weight_qspec(quantization_config)
 
+        # Only annotate dynamically quantized conv if it's 2D and not depthwise
+        if (
+            quantization_config
+            and quantization_config.input_activation
+            and quantization_config.input_activation.is_dynamic
+        ):
+            weight_val = weight.meta.get("val", None)
+            weight_shape = getattr(weight_val, "shape", None)
+
+            # Skip if not a 4D weight tensor (i.e. not conv2d)
+            if weight_shape is not None and len(weight_shape) != 4:
+                continue
+
+            # Skip if depthwise (default to groups=1 since it's not an arg)
+            if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
+                continue
+
         # adding weight node to the partition as well
         partition = [conv_node, conv_node.args[1]]
```
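The 4D-weight check is how the annotator tells conv2d apart from conv1d/conv3d at this point, since the conv pattern match covers all of them. A quick illustration of the shapes involved:

```python
import torch

# Conv1d weights are 3D, Conv2d weights are 4D, Conv3d weights are 5D;
# the annotation guard above only lets the 4D (conv2d) case through.
print(torch.nn.Conv1d(3, 8, 3).weight.shape)  # torch.Size([8, 3, 3])
print(torch.nn.Conv2d(3, 8, 3).weight.shape)  # torch.Size([8, 3, 3, 3])
print(torch.nn.Conv3d(3, 8, 3).weight.shape)  # torch.Size([8, 3, 3, 3, 3])
```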

backends/xnnpack/runtime/XNNCompiler.cpp (+1, -6)

```diff
@@ -512,11 +512,6 @@ Error defineTensor(
       buffer_ptr == nullptr,
       Internal,
       "Dynamically quantized tensor should not have constant data but found non-nullptr");
-  // TODO(T179441835): Dynamic Quantization with num_nonbatch_dims > 1
-  ET_CHECK_OR_RETURN_ERROR(
-      qparams->num_nonbatch_dims() == 1,
-      Internal,
-      "Dynamically Quantized Tensors currently only support per token quantization");
   status = xnn_define_dynamically_quantized_tensor_value(
       /*subgraph=*/subgraph_ptr,
       /*datatype=*/getDataType(tensor_value->datatype()),
@@ -1172,7 +1167,7 @@ Error defineStaticTransposeNode(
   ET_CHECK_OR_RETURN_ERROR(
       status == xnn_status_success,
       Internal,
-      "Failed to create sigmoid node %i with code: %s",
+      "Failed to create static transpose node %i with code: %s",
       node->debug_handle(),
       xnn_status_to_string(status));
```

backends/xnnpack/test/ops/test_conv2d.py (+96, -1)

```diff
@@ -18,6 +18,10 @@
 except:
     has_quantized_ops = False
 
+from executorch.backends.xnnpack.partition.config.xnnpack_config import (
+    ConfigPrecisionType,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
@@ -26,7 +30,7 @@
 )
 from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
 from executorch.backends.xnnpack.test.tester import Quantize, Tester
-
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
 from executorch.exir.dialects._ops import ops as exir_ops
 
 
@@ -169,6 +173,43 @@ def get_inputs(self):
         return (torch.randn(2, 2, 4, 4),)
 
 
+class Conv2dDQSeq(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.first = torch.nn.Conv2d(
+            in_channels=3, out_channels=8, kernel_size=3, padding=1
+        )
+        self.second = torch.nn.Conv2d(
+            in_channels=8, out_channels=10, kernel_size=3, padding=1
+        )
+
+    def forward(self, x):
+        y = self.first(x)
+        return self.second(y)
+
+    def get_inputs(self):
+        return (torch.randn(1, 3, 8, 8),)
+
+
+class Conv2dDQParallel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.first = torch.nn.Conv2d(
+            in_channels=3, out_channels=8, kernel_size=3, padding=1
+        )
+        self.second = torch.nn.Conv2d(
+            in_channels=3, out_channels=8, kernel_size=3, padding=1
+        )
+
+    def forward(self, x):
+        first = self.first(x)
+        second = self.second(x)
+        return first, second
+
+    def get_inputs(self):
+        return (torch.randn(1, 3, 8, 8),)
+
+
 class TestConv2d(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
@@ -223,6 +264,37 @@ def _test(
             .run_method_and_compare_outputs(qtol=1)
         )
 
+    def _test_dq(
+        self,
+        m: torch.nn.Module,
+        conv_count=1,
+        dynamic_shapes=None,
+    ):
+        quant_config = get_symmetric_quantization_config(
+            is_per_channel=True,
+            is_dynamic=True,
+        )
+
+        DynamicallyQuantizedPartitioner = XnnpackPartitioner(
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
+            per_op_mode=True,
+        )
+
+        tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes)
+        tester.quantize(Quantize(quantization_config=quant_config))
+        tester.export()
+        tester.check(["torch.ops.quantized_decomposed.choose_qparams"])
+        tester.to_edge_transform_and_lower(
+            ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
+        )
+        tester.check_count(
+            {"torch.ops.higher_order.executorch_call_delegate": conv_count}
+        )
+        tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"])
+        tester.to_executorch()
+        tester.serialize()
+        tester.run_method_and_compare_outputs(qtol=1)
+
     def test_fp16_conv2d(self) -> None:
         for transpose in (True, False):
             for has_bias in (True, False):
@@ -699,3 +771,26 @@ def forward(self, x):
             .serialize()
             .run_method_and_compare_outputs(qtol=1)
         )
+
+    def test_dq_conv2d(self) -> None:
+        model = Conv2d(
+            in_channels=3,
+            out_channels=10,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(0, 0),
+            batches=1,
+            width=8,
+            height=8,
+        )
+        self._test_dq(model)
+
+    def test_dq_conv2d_seq(self) -> None:
+        model = Conv2dDQSeq()
+        conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
+        self._test_dq(model, conv_count)
+
+    def test_dq_conv2d_parallel(self) -> None:
+        model = Conv2dDQParallel()
+        conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
+        self._test_dq(model, conv_count)
```

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py (+42, -1)

```diff
@@ -10,10 +10,13 @@
 from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
     ChannelsLastTaggedReshapePass,
 )
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+)
 from executorch.backends.xnnpack.test.test_xnnpack_utils_classes import (
     OpSequencesAddConv2d,
 )
-from executorch.backends.xnnpack.test.tester import RunPasses, Tester
+from executorch.backends.xnnpack.test.tester import Quantize, RunPasses, Tester
 
 
 class TestChannelsLastTaggedReshapePass(unittest.TestCase):
@@ -35,6 +38,10 @@ def setUp(self):
     dequant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
     conv_name = "executorch_exir_dialects_edge__ops_aten_convolution_default"
     relu_name = "executorch_exir_dialects_edge__ops_aten_relu_default"
+    choose_qparams_name = (
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_tensor"
+    )
+    dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
 
     def test_fp32_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
@@ -179,3 +186,37 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self):
             )
             .run_method_and_compare_outputs()
         )
+
+    class Conv2dDynamicQuant(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 10, 3)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
+        (
+            Tester(self.Conv2dDynamicQuant().eval(), (torch.randn(1, 3, 8, 8),))
+            .quantize(
+                Quantize(
+                    quantization_config=get_symmetric_quantization_config(
+                        is_dynamic=True
+                    )
+                )
+            )
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check(
+                [
+                    self.to_copy_name,
+                    self.choose_qparams_name,
+                    self.dynamic_quant_name,
+                    self.dequant_name,
+                    self.conv_name,
+                    self.to_copy_name,
+                ]
+            )
+            .run_method_and_compare_outputs()
+        )
```
