From 153b2588968b8e1518134ef1833163e188f88972 Mon Sep 17 00:00:00 2001
From: maxren
Date: Sat, 19 Aug 2023 18:03:01 -0700
Subject: [PATCH 1/2] fix permute copy name

Differential Revision: D48488930

fbshipit-source-id: 25ad98483f1903f2671cc4612fae2468319a2766
---
 backends/xnnpack/operators/__init__.py                  | 2 +-
 .../operators/{op_static_transpose.py => op_permute.py} | 2 +-
 backends/xnnpack/operators/op_skip_ops.py               | 9 ---------
 3 files changed, 2 insertions(+), 11 deletions(-)
 rename backends/xnnpack/operators/{op_static_transpose.py => op_permute.py} (97%)

diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py
index dcffa42ac3f..9d81b7f8e29 100644
--- a/backends/xnnpack/operators/__init__.py
+++ b/backends/xnnpack/operators/__init__.py
@@ -30,6 +30,7 @@
     op_minimum,
     op_multiply,
     op_negate,
+    op_permute,
     op_prelu,
     op_quantize_per_tensor,
     op_relu,
@@ -42,7 +43,6 @@
     op_squeeze,
     op_static_constant_pad,
     op_static_resize_bilinear_2d,
-    op_static_transpose,
     op_sub,
     op_to_copy,
 )
diff --git a/backends/xnnpack/operators/op_static_transpose.py b/backends/xnnpack/operators/op_permute.py
similarity index 97%
rename from backends/xnnpack/operators/op_static_transpose.py
rename to backends/xnnpack/operators/op_permute.py
index ce1cd43c1ad..0ca92a7a039 100644
--- a/backends/xnnpack/operators/op_static_transpose.py
+++ b/backends/xnnpack/operators/op_permute.py
@@ -20,7 +20,7 @@


 @register_node_visitor
-class StaticTransposeVisitor(NodeVisitor):
+class PermuteVisitor(NodeVisitor):
     target = "aten.permute_copy.default"

     def __init__(self, *args) -> None:
diff --git a/backends/xnnpack/operators/op_skip_ops.py b/backends/xnnpack/operators/op_skip_ops.py
index 83b6eee32b0..345b7896d34 100644
--- a/backends/xnnpack/operators/op_skip_ops.py
+++ b/backends/xnnpack/operators/op_skip_ops.py
@@ -113,12 +113,3 @@ class OpSymSizeInt(OpSkipOps):
     """

     target = "sym_size.int"
-
-
-@register_node_visitor
-class OpPermuteCopyDefault(OpSkipOps):
-    """
-    do nothing if node is permute_copy.default
-    """
-
-    target = "aten.permute_copy.default"

From e7e6051ab2ca2fc8a2069036c8aa17741fa7e236 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Sat, 19 Aug 2023 18:03:14 -0700
Subject: [PATCH 2/2] Temporarily add aten.op to supported quant modules (#81)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/81

The quant flow for graph capturing is outlined here: https://fb.workplace.com/groups/257735836456307/permalink/545316467698241/

The flow becomes:
```
capture_pre_autograd_graph --> prepare --> convert --> exir.capture
```

As a result, when we capture the converted GraphModule, source_fn changes from the originating torch.nn.Module to the aten operator target (we are recapturing a GraphModule, not a torch.nn.Module). I believe someone is currently working on a fix for this, but until then we have to add torch.ops.aten.* targets to our supported modules.
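For reference, a minimal sketch of that flow. This is illustrative only, not code from this diff: the toy model, the XNNPACKQuantizer setup, and the import paths are assumptions based on the pt2e quantization APIs of this era and may differ by release.
```
import torch
import executorch.exir as exir
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Linear(4, 4).eval()
example_inputs = (torch.randn(1, 4),)

# 1. capture_pre_autograd_graph: trace a graph suitable for quantization
pre_autograd_gm = capture_pre_autograd_graph(model, example_inputs)

# 2. prepare: insert observers, then run a calibration pass
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared_gm = prepare_pt2e(pre_autograd_gm, quantizer)
prepared_gm(*example_inputs)

# 3. convert: fold observers into quantize/dequantize ops
converted_gm = convert_pt2e(prepared_gm)

# 4. exir.capture: re-capture the converted GraphModule. Because the input
# here is a GraphModule rather than a torch.nn.Module, source_fn on the
# resulting nodes is the aten operator target (e.g. torch.ops.aten.addmm.default)
# instead of the originating module, which is why the partitioner configs
# below list aten targets directly.
edge_program = exir.capture(converted_gm, example_inputs).to_edge()
```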
Differential Revision: D48488927

fbshipit-source-id: 59e0797170b6030af123ad79e63de96d8a853df2
---
 backends/xnnpack/partition/configs.py        | 4 ++++
 backends/xnnpack/passes/convert_to_linear.py | 1 +
 2 files changed, 5 insertions(+)

diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py
index a3b66d3fcb4..b2e6dbc1c77 100644
--- a/backends/xnnpack/partition/configs.py
+++ b/backends/xnnpack/partition/configs.py
@@ -111,6 +111,10 @@
     torch.nn.functional.leaky_relu,
     torch.nn.functional.leaky_relu_,
     torch.nn.LeakyReLU,
+    # TODO(): In the quant --> export flow, source_fn is the operator target instead of the module name.
+    # This is actively being fixed, but until then, we add these operator target names to the partitioner.
+    torch.ops.aten.convolution.default,
+    torch.ops.aten.addmm.default,
 ]

 SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET = set(SUPPORTED_QUANT_MODULES)
diff --git a/backends/xnnpack/passes/convert_to_linear.py b/backends/xnnpack/passes/convert_to_linear.py
index e58971589f3..1d79da4493a 100644
--- a/backends/xnnpack/passes/convert_to_linear.py
+++ b/backends/xnnpack/passes/convert_to_linear.py
@@ -27,6 +27,7 @@ class ConvertToLinearPass(ExportPass):
     linear_modules = [
         torch.nn.Linear,
         torch.nn.functional.linear,
+        torch.ops.aten.addmm.default,
     ]

     targets = [