Skip to content

Commit 29293a2

Browse files
mcr229facebook-github-bot
authored andcommitted
Fix linear's with permute copy
Differential Revision: https://internalfb.com/D48488931 fbshipit-source-id: 2229a16df1f5b2cdc21a1b5ef4fa0c86a12d2c3a
1 parent 6cae36f commit 29293a2

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,9 @@ def get_input_deps( # noqa
682682
"""
683683
nodes = set()
684684
for inp in input_nodes:
685+
if inp.target == exir_ops.edge.aten.permute_copy.default:
686+
nodes.add(inp)
687+
inp = cast(torch.fx.Node, inp.args[0])
685688
if inp.target in self._DQ_OPS:
686689
# dequant node
687690
nodes.add(inp)

backends/xnnpack/passes/convert_to_linear.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ def create_linear(
120120
src_partition.input_nodes
121121
+ src_partition.params, # non quant weight can be in params
122122
)
123+
if linear_weight.target == exir_ops.edge.aten.permute_copy.default:
124+
linear_weight = linear_weight.args[0]
123125
logger.debug(f"Found weight: {linear_weight} from node {node}")
124126

125127
linear_bias = self.find(

0 commit comments

Comments
 (0)