Skip to content

Commit 6d92cdd

Browse files
cccclaifacebook-github-bot
authored andcommitted
patch rms norm recompose pass
Differential Revision: D69090807
1 parent 0a936e0 commit 6d92cdd

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

backends/qualcomm/_passes/recompose_rms_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _get_gamma_node(self, output_node):
3434

3535
def call(self, graph_module: torch.fx.GraphModule):
3636
graph = graph_module.graph
37-
partitions = get_source_partitions(graph, [torch.nn.RMSNorm])
37+
partitions = get_source_partitions(graph, [torch.nn.RMSNorm, torch.ops.aten.rms_norm.default])
3838
for _, src_partitions in partitions.items():
3939
for src_partition in src_partitions:
4040
input_len = len(src_partition.input_nodes)

0 commit comments

Comments
 (0)