|
18 | 18 | except:
|
19 | 19 | has_quantized_ops = False
|
20 | 20 |
|
| 21 | +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( |
| 22 | + ConfigPrecisionType, |
| 23 | +) |
| 24 | +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner |
21 | 25 | from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
|
22 | 26 | get_symmetric_quantization_config,
|
23 | 27 | )
|
|
26 | 30 | )
|
27 | 31 | from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
|
28 | 32 | from executorch.backends.xnnpack.test.tester import Quantize, Tester
|
29 |
| - |
| 33 | +from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower |
30 | 34 | from executorch.exir.dialects._ops import ops as exir_ops
|
31 | 35 |
|
32 | 36 |
|
@@ -169,6 +173,43 @@ def get_inputs(self):
|
169 | 173 | return (torch.randn(2, 2, 4, 4),)
|
170 | 174 |
|
171 | 175 |
|
class Conv2dDQSeq(torch.nn.Module):
    """Chain of two 3x3 convolutions (3 -> 8 -> 10 channels, padding 1).

    Used to check that each conv in a sequential graph is lowered to its
    own delegate under per-op dynamic quantization.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (`first`, `second`) are kept stable so state_dict
        # keys and module-walking tests are unaffected.
        self.first = torch.nn.Conv2d(
            in_channels=3, out_channels=8, kernel_size=3, padding=1
        )
        self.second = torch.nn.Conv2d(
            in_channels=8, out_channels=10, kernel_size=3, padding=1
        )

    def forward(self, x):
        # Compose the two convs directly; padding=1 keeps spatial dims.
        return self.second(self.first(x))

    def get_inputs(self):
        # Single NCHW input matching `first`'s in_channels.
        return (torch.randn(1, 3, 8, 8),)
| 193 | + |
class Conv2dDQParallel(torch.nn.Module):
    """Two identical, independent 3x3 convolutions applied to one input.

    Used to check that parallel convs each get their own delegate under
    per-op dynamic quantization.
    """

    def __init__(self):
        super().__init__()
        # Both branches share the same configuration; build them from one
        # kwargs dict to make that explicit.
        conv_kwargs = dict(in_channels=3, out_channels=8, kernel_size=3, padding=1)
        self.first = torch.nn.Conv2d(**conv_kwargs)
        self.second = torch.nn.Conv2d(**conv_kwargs)

    def forward(self, x):
        # Same input feeds both branches; return both results.
        return self.first(x), self.second(x)

    def get_inputs(self):
        # Single NCHW input matching both branches' in_channels.
        return (torch.randn(1, 3, 8, 8),)
| 212 | + |
172 | 213 | class TestConv2d(unittest.TestCase):
|
173 | 214 | def setUp(self):
|
174 | 215 | torch._dynamo.reset()
|
@@ -223,6 +264,37 @@ def _test(
|
223 | 264 | .run_method_and_compare_outputs(qtol=1)
|
224 | 265 | )
|
225 | 266 |
|
| 267 | + def _test_dq( |
| 268 | + self, |
| 269 | + m: torch.nn.Module, |
| 270 | + conv_count=1, |
| 271 | + dynamic_shapes=None, |
| 272 | + ): |
| 273 | + quant_config = get_symmetric_quantization_config( |
| 274 | + is_per_channel=True, |
| 275 | + is_dynamic=True, |
| 276 | + ) |
| 277 | + |
| 278 | + DynamicallyQuantizedPartitioner = XnnpackPartitioner( |
| 279 | + config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, |
| 280 | + per_op_mode=True, |
| 281 | + ) |
| 282 | + |
| 283 | + tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes) |
| 284 | + tester.quantize(Quantize(quantization_config=quant_config)) |
| 285 | + tester.export() |
| 286 | + tester.check(["torch.ops.quantized_decomposed.choose_qparams"]) |
| 287 | + tester.to_edge_transform_and_lower( |
| 288 | + ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner]) |
| 289 | + ) |
| 290 | + tester.check_count( |
| 291 | + {"torch.ops.higher_order.executorch_call_delegate": conv_count} |
| 292 | + ) |
| 293 | + tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"]) |
| 294 | + tester.to_executorch() |
| 295 | + tester.serialize() |
| 296 | + tester.run_method_and_compare_outputs(qtol=1) |
| 297 | + |
226 | 298 | def test_fp16_conv2d(self) -> None:
|
227 | 299 | for transpose in (True, False):
|
228 | 300 | for has_bias in (True, False):
|
@@ -699,3 +771,26 @@ def forward(self, x):
|
699 | 771 | .serialize()
|
700 | 772 | .run_method_and_compare_outputs(qtol=1)
|
701 | 773 | )
|
| 774 | + |
| 775 | + def test_dq_conv2d(self) -> None: |
| 776 | + model = Conv2d( |
| 777 | + in_channels=3, |
| 778 | + out_channels=10, |
| 779 | + kernel_size=(3, 3), |
| 780 | + stride=(1, 1), |
| 781 | + padding=(0, 0), |
| 782 | + batches=1, |
| 783 | + width=8, |
| 784 | + height=8, |
| 785 | + ) |
| 786 | + self._test_dq(model) |
| 787 | + |
| 788 | + def test_dq_conv2d_seq(self) -> None: |
| 789 | + model = Conv2dDQSeq() |
| 790 | + conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d) |
| 791 | + self._test_dq(model, conv_count) |
| 792 | + |
| 793 | + def test_dq_conv2d_parallel(self) -> None: |
| 794 | + model = Conv2dDQParallel() |
| 795 | + conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d) |
| 796 | + self._test_dq(model, conv_count) |
0 commit comments