10
10
11
11
import torch
12
12
import torch .fx
13
- from executorch .backends .arm ._passes .arm_pass_utils import create_node
13
+ from executorch .backends .arm ._passes .arm_pass_utils import (
14
+ create_node ,
15
+ get_node_arg ,
16
+ set_node_arg ,
17
+ )
14
18
from executorch .exir .dialects ._ops import ops as exir_ops
15
19
from executorch .exir .pass_base import ExportPass , PassResult
16
20
17
21
18
- class InsertSqueezeAfterSumPass (ExportPass ):
22
+ class KeepDimsFalseToSqueezePass (ExportPass ):
19
23
"""
20
- In Pytorch, the default behaviour of Tensor.sum is to squeeze
24
+ In Pytorch, the default behaviour of for example Tensor.sum is to squeeze
21
25
the dimension that is summed (keep_dim = False).
22
26
However, in TOSA, REDUCE_SUM always preserves the
23
27
rank of the input (keep_dim = True).
@@ -31,28 +35,52 @@ class InsertSqueezeAfterSumPass(ExportPass):
31
35
squeeze(dim = dims)
32
36
"""
33
37
38
+ # CURRENTLY NOT HANDLED OPS
39
+ # exir_ops.edge.aten.amax,
40
+ # exir_ops.edge.aten.amin,
41
+ # exir_ops.edge.aten.any.dim,
42
+ # exir_ops.edge.aten.any.dims,
43
+ # exir_ops.edge.aten.argmax,
44
+ # exir_ops.edge.aten.argmin,
45
+ # exir_ops.edge.aten.max.dim,
46
+ # exir_ops.edge.aten.min.dim,
47
+ # exir_ops.edge.aten.prod.dim_int,
48
+
49
+ # HANDLED OPS
50
+ # exir_ops.edge.aten.sum.dim_IntList
51
+ # exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
52
+ # exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
53
+ # exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)
54
+
34
55
def call (self , graph_module : torch .fx .GraphModule ):
35
56
for node in graph_module .graph .nodes :
57
+ keep_dim_index = None
58
+
36
59
if node .op != "call_function" :
37
60
continue
38
- if node .target != exir_ops .edge .aten .sum .dim_IntList :
61
+ if node .target == exir_ops .edge .aten .sum .dim_IntList :
62
+ keep_dim_index = 2
63
+ else :
39
64
continue
65
+
40
66
sum_node = cast (torch .fx .Node , node )
41
- keep_dim = cast (bool , sum_node .args [2 ] if len (sum_node .args ) > 2 else False )
67
+ keep_dim = get_node_arg (sum_node .args , keep_dim_index , False )
68
+
42
69
if keep_dim :
43
70
continue
44
71
45
- dim_list = cast ( list [ int ], sum_node .args [ 1 ])
72
+ dim_list = get_node_arg ( sum_node .args , 1 , [ 0 ])
46
73
47
74
# Add keep_dim = True arg to sum node.
48
- sum_node . args = sum_node . args [ 0 : 2 ] + ( True , )
75
+ set_node_arg ( sum_node , 2 , True )
49
76
50
77
with graph_module .graph .inserting_after (sum_node ):
51
78
squeeze_node = create_node (
52
79
graph_module .graph , exir_ops .edge .aten .squeeze_copy .dims , ()
53
80
)
54
81
sum_node .replace_all_uses_with (squeeze_node )
55
82
squeeze_node .args = (sum_node , dim_list )
83
+
56
84
graph_module .graph .eliminate_dead_code ()
57
85
graph_module .recompile ()
58
86
graph_module = super ().call (graph_module ).graph_module
0 commit comments