From 0a2001703cf27910a9449e10d96132f340916341 Mon Sep 17 00:00:00 2001
From: maxren
Date: Sat, 19 Aug 2023 18:05:11 -0700
Subject: [PATCH 1/3] fix permute copy name

Differential Revision: D48488930

fbshipit-source-id: 41fdd20b120b6056a1beb00c0dd78a4893eed1d8
---
 backends/xnnpack/operators/__init__.py                  | 2 +-
 .../operators/{op_static_transpose.py => op_permute.py} | 2 +-
 backends/xnnpack/operators/op_skip_ops.py               | 9 ---------
 3 files changed, 2 insertions(+), 11 deletions(-)
 rename backends/xnnpack/operators/{op_static_transpose.py => op_permute.py} (97%)

diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py
index dcffa42ac3f..9d81b7f8e29 100644
--- a/backends/xnnpack/operators/__init__.py
+++ b/backends/xnnpack/operators/__init__.py
@@ -30,6 +30,7 @@
     op_minimum,
     op_multiply,
     op_negate,
+    op_permute,
     op_prelu,
     op_quantize_per_tensor,
     op_relu,
@@ -42,7 +43,6 @@
     op_squeeze,
     op_static_constant_pad,
     op_static_resize_bilinear_2d,
-    op_static_transpose,
     op_sub,
     op_to_copy,
 )
diff --git a/backends/xnnpack/operators/op_static_transpose.py b/backends/xnnpack/operators/op_permute.py
similarity index 97%
rename from backends/xnnpack/operators/op_static_transpose.py
rename to backends/xnnpack/operators/op_permute.py
index ce1cd43c1ad..0ca92a7a039 100644
--- a/backends/xnnpack/operators/op_static_transpose.py
+++ b/backends/xnnpack/operators/op_permute.py
@@ -20,7 +20,7 @@
 
 
 @register_node_visitor
-class StaticTransposeVisitor(NodeVisitor):
+class PermuteVisitor(NodeVisitor):
     target = "aten.permute_copy.default"
 
     def __init__(self, *args) -> None:
diff --git a/backends/xnnpack/operators/op_skip_ops.py b/backends/xnnpack/operators/op_skip_ops.py
index 83b6eee32b0..345b7896d34 100644
--- a/backends/xnnpack/operators/op_skip_ops.py
+++ b/backends/xnnpack/operators/op_skip_ops.py
@@ -113,12 +113,3 @@ class OpSymSizeInt(OpSkipOps):
     """
 
     target = "sym_size.int"
-
-
-@register_node_visitor
-class OpPermuteCopyDefault(OpSkipOps):
-    """
-    do nothing if node is permute_copy.default
-    """
-
-    target = "aten.permute_copy.default"

From ea49ccd2ad12fa869347b7f2da22e0939d87af1f Mon Sep 17 00:00:00 2001
From: maxren
Date: Sat, 19 Aug 2023 18:05:11 -0700
Subject: [PATCH 2/3] Temporarily add aten.op to supported quant modules,
 Commit 1

Differential Revision: https://internalfb.com/D48488927

fbshipit-source-id: 373898025cf15a3313632fe73bd60a7cbeea479c
---
 backends/xnnpack/partition/configs.py        | 4 ++++
 backends/xnnpack/passes/convert_to_linear.py | 1 +
 2 files changed, 5 insertions(+)

diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py
index a3b66d3fcb4..b2e6dbc1c77 100644
--- a/backends/xnnpack/partition/configs.py
+++ b/backends/xnnpack/partition/configs.py
@@ -111,6 +111,10 @@
    torch.nn.functional.leaky_relu,
    torch.nn.functional.leaky_relu_,
    torch.nn.LeakyReLU,
+   # TODO(): In the quant --> export flow, source_fn is the operator target instead of the module name.
+   # This is actively being fixed, but until then, we add these operator target names to the partitioner.
+   torch.ops.aten.convolution.default,
+   torch.ops.aten.addmm.default,
 ]
 
 SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET = set(SUPPORTED_QUANT_MODULES)
diff --git a/backends/xnnpack/passes/convert_to_linear.py b/backends/xnnpack/passes/convert_to_linear.py
index e58971589f3..1d79da4493a 100644
--- a/backends/xnnpack/passes/convert_to_linear.py
+++ b/backends/xnnpack/passes/convert_to_linear.py
@@ -27,6 +27,7 @@ class ConvertToLinearPass(ExportPass):
     linear_modules = [
         torch.nn.Linear,
         torch.nn.functional.linear,
+        torch.ops.aten.addmm.default,
     ]
 
     targets = [

From 8137f4666a9a7e53efa82ea245f63fdee50897f8 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Sat, 19 Aug 2023 18:05:32 -0700
Subject: [PATCH 3/3] Fix linears with permute copy (#79)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/79

There are some issues with permute_copy in both the partitioner and the
convert_to_linear pass.

### Partitioner

For quantized partitions, we fail to pull in the q/dq nodes above
permute_copy:

```
get_attr --> q --> dq --> permute_copy --> addmm
```

The solution is to check the inputs to the source partition for a
permute_copy node; if one of the inputs is a permute_copy, we add it to
the partition and check its inputs as well.

### Convert to Linear Pass

In the pattern

```
get_attr --> q --> dq --> permute_copy --> addmm
```

we replace the entire source partition with just linear, but we fail to
delete the permute_copy because it, rather than the dq, is an input to the
source partition (the dq should actually be the input to the linear
source partition). The weight given to linear should not be the result of
the permute_copy; it should be the input to the permute_copy. This happens
because q and dq are not tagged as part of the linear source partition, so
the permute_copy becomes the input to the source partition.

Differential Revision: D48488931

fbshipit-source-id: 6269660d4c07cb9d3a7c5c020efd8ce753a40fe3
---
 backends/xnnpack/partition/xnnpack_partitioner.py | 3 +++
 backends/xnnpack/passes/convert_to_linear.py      | 2 ++
 2 files changed, 5 insertions(+)

diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py
index 59319358993..262a9301980 100644
--- a/backends/xnnpack/partition/xnnpack_partitioner.py
+++ b/backends/xnnpack/partition/xnnpack_partitioner.py
@@ -682,6 +682,9 @@ def get_input_deps(  # noqa
         """
         nodes = set()
         for inp in input_nodes:
+            if inp.target == exir_ops.edge.aten.permute_copy.default:
+                nodes.add(inp)
+                inp = cast(torch.fx.Node, inp.args[0])
             if inp.target in self._DQ_OPS:
                 # dequant node
                 nodes.add(inp)
diff --git a/backends/xnnpack/passes/convert_to_linear.py b/backends/xnnpack/passes/convert_to_linear.py
index 1d79da4493a..d1a91bb3048 100644
--- a/backends/xnnpack/passes/convert_to_linear.py
+++ b/backends/xnnpack/passes/convert_to_linear.py
@@ -120,6 +120,8 @@ def create_linear(
             src_partition.input_nodes
             + src_partition.params,  # non quant weight can be in params
         )
+        if linear_weight.target == exir_ops.edge.aten.permute_copy.default:
+            linear_weight = linear_weight.args[0]
         logger.debug(f"Found weight: {linear_weight} from node {node}")

         linear_bias = self.find(