Commit 5ca981f
mcr229 authored and facebook-github-bot committed
Fix linear's with permute copy
Summary: There are issues with how permute_copy is handled in both the partitioner and the convert_to_linear pass.

### Partitioner

For quantized partitions, we fail to pull in the q/dq nodes above permute_copy:

```
get_attr --> q --> dq --> permute_copy --> addmm
```

The fix is to check the inputs to the source partition for a permute_copy node. If an input is one, we add it to the partition and then check its inputs.

### Convert to Linear Pass

In the pattern

```
get_attr --> q --> dq --> permute_copy --> addmm
```

we replace the entire source partition with a single linear node. However, we fail to delete the permute_copy, because it, rather than the dq, is the input to the source partition (dq should actually be the input to the linear source partition). This happens because q and dq are not tagged as part of the linear source partition, so permute_copy becomes the input to the source partition. The weight given to linear should therefore be the input to permute_copy, not its result.

Differential Revision: D48488931

fbshipit-source-id: a650a334cca2ce2e9da8f04805b519a19cbf1011
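Why linear should take the input to permute_copy rather than its result can be checked with small numbers. The following is a minimal pure-Python sketch, not code from this commit: `transpose`, `matmul`, `addmm`, and `linear` are hypothetical helpers written for illustration, assuming the usual conventions that addmm computes `bias + x @ mat` and linear computes `x @ weight^T + bias`.

```python
def transpose(m):
    """Return the transpose of a 2-D list (stand-in for permute_copy on a 2-D weight)."""
    return [list(row) for row in zip(*m)]


def matmul(a, b):
    """Plain 2-D matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def addmm(bias, x, mat):
    """bias + x @ mat, with bias broadcast across rows."""
    return [[v + b for v, b in zip(row, bias)] for row in matmul(x, mat)]


def linear(x, weight, bias):
    """x @ weight^T + bias, where weight has shape (out_features, in_features)."""
    return [
        [sum(a * w for a, w in zip(row, wrow)) + b for wrow, b in zip(weight, bias)]
        for row in x
    ]


x = [[1, 2, 3], [4, 5, 6]]          # (batch=2, in=3)
weight = [[1, 0, 2], [0, 1, 1]]     # (out=2, in=3)
bias = [10, 20]

# addmm consumes the permuted weight; linear consumes the original weight.
via_addmm = addmm(bias, x, transpose(weight))
via_linear = linear(x, weight, bias)
```

Both paths give the same result here, which is why the permute_copy node becomes redundant once addmm is rewritten as linear.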
1 parent cbedce9 commit 5ca981f

2 files changed

+5
-0
lines changed

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 3 additions & 0 deletions
@@ -682,6 +682,9 @@ def get_input_deps( # noqa
         """
         nodes = set()
         for inp in input_nodes:
+            if inp.target == exir_ops.edge.aten.permute_copy.default:
+                nodes.add(inp)
+                inp = cast(torch.fx.Node, inp.args[0])
             if inp.target in self._DQ_OPS:
                 # dequant node
                 nodes.add(inp)
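
The effect of this hunk can be illustrated without torch. The sketch below is a simplified stand-in, not the partitioner itself: `Node`, `collect_input_deps`, and the string targets are hypothetical names that replace `torch.fx.Node` and the `exir_ops` targets, and only the two steps visible in the hunk (the permute_copy step and the dq check) are modeled.

```python
from dataclasses import dataclass
from typing import Iterable, Set, Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for torch.fx.Node: a name, a target, and input args."""
    name: str
    target: str
    args: Tuple["Node", ...] = ()


PERMUTE_COPY = "permute_copy"          # stands in for exir_ops.edge.aten.permute_copy.default
DQ_OPS = {"dequantize_per_tensor"}     # stands in for self._DQ_OPS


def collect_input_deps(input_nodes: Iterable[Node]) -> Set[Node]:
    """Collect the nodes above the partition inputs that belong with the partition."""
    nodes: Set[Node] = set()
    for inp in input_nodes:
        if inp.target == PERMUTE_COPY:
            # Pull in the permute_copy, then keep looking at what feeds it.
            nodes.add(inp)
            inp = inp.args[0]
        if inp.target in DQ_OPS:
            # dequant node
            nodes.add(inp)
    return nodes


# get_attr --> q --> dq --> permute_copy --> (addmm partition)
weight = Node("weight", "get_attr")
q = Node("q", "quantize_per_tensor", (weight,))
dq = Node("dq", "dequantize_per_tensor", (q,))
permute = Node("permute", PERMUTE_COPY, (dq,))

deps = collect_input_deps([permute])
names = sorted(n.name for n in deps)
```

Without the permute_copy branch, the loop would inspect `permute` directly, find that it is not a dq op, and collect nothing. With it, both `permute` and `dq` are collected.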

backends/xnnpack/passes/convert_to_linear.py

Lines changed: 2 additions & 0 deletions
@@ -120,6 +120,8 @@ def create_linear(
                 src_partition.input_nodes
                 + src_partition.params,  # non quant weight can be in params
             )
+            if linear_weight.target == exir_ops.edge.aten.permute_copy.default:
+                linear_weight = linear_weight.args[0]
             logger.debug(f"Found weight: {linear_weight} from node {node}")
 
             linear_bias = self.find(
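
The two added lines amount to stepping from a permute_copy node to its input before the weight is handed to linear. A minimal sketch of that step follows, again using a hypothetical `Node` stand-in and string targets rather than `torch.fx.Node` and `exir_ops`. `unwrap_permute` is an illustrative helper name, not a function in the pass.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for torch.fx.Node."""
    name: str
    target: str
    args: Tuple["Node", ...] = ()


PERMUTE_COPY = "permute_copy"  # stands in for exir_ops.edge.aten.permute_copy.default


def unwrap_permute(linear_weight: Node) -> Node:
    """Return the node feeding permute_copy, or the node itself if it is not a permute_copy."""
    if linear_weight.target == PERMUTE_COPY:
        return linear_weight.args[0]
    return linear_weight


# Quantized pattern: the weight found by the pass is the permute_copy over dq.
weight = Node("weight", "get_attr")
dq = Node("dq", "dequantize_per_tensor", (weight,))
permute = Node("permute", PERMUTE_COPY, (dq,))

quantized_weight = unwrap_permute(permute)   # steps back to dq
plain_weight = unwrap_permute(weight)        # left unchanged
```

In the quantized pattern the linear node then reads from `dq`, which leaves the permute_copy without a consumer on this path.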
