20
20
21
21
22
22
class Equal (torch .nn .Module ):
23
- aten_op_BI = "torch.ops.aten.eq.Tensor"
24
- aten_op_MI = "torch.ops.aten.eq.Scalar"
23
+ aten_op_Tensor = "torch.ops.aten.eq.Tensor"
24
+ aten_op_Scalar = "torch.ops.aten.eq.Scalar"
25
25
exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor"
26
26
27
27
def __init__ (self , input , other ):
@@ -80,7 +80,7 @@ def get_inputs(self):
80
80
@common .parametrize ("test_module" , test_data_tensor )
81
81
def test_eq_tensor_tosa_MI (test_module ):
82
82
pipeline = TosaPipelineMI [input_t ](
83
- test_module , test_module .get_inputs (), Equal .aten_op_BI , Equal .exir_op
83
+ test_module , test_module .get_inputs (), Equal .aten_op_Tensor , Equal .exir_op
84
84
)
85
85
pipeline .run ()
86
86
@@ -90,7 +90,7 @@ def test_eq_scalar_tosa_MI(test_module):
90
90
pipeline = TosaPipelineMI [input_t ](
91
91
test_module ,
92
92
test_module .get_inputs (),
93
- Equal .aten_op_MI ,
93
+ Equal .aten_op_Scalar ,
94
94
Equal .exir_op ,
95
95
)
96
96
pipeline .run ()
@@ -99,7 +99,7 @@ def test_eq_scalar_tosa_MI(test_module):
99
99
@common .parametrize ("test_module" , test_data_tensor | test_data_scalar )
100
100
def test_eq_tosa_BI (test_module ):
101
101
pipeline = TosaPipelineBI [input_t ](
102
- test_module , test_module .get_inputs (), Equal .aten_op_BI , Equal .exir_op
102
+ test_module , test_module .get_inputs (), Equal .aten_op_Tensor , Equal .exir_op
103
103
)
104
104
pipeline .run ()
105
105
@@ -135,15 +135,17 @@ def test_eq_scalar_u55_BI(test_module):
135
135
"test_module" ,
136
136
test_data_tensor | test_data_scalar ,
137
137
xfails = {
138
- "eq_tensor_rank4_randn" : "4D fails because boolean Tensors can't be subtracted" ,
138
+ "eq_tensor_rank4_randn" : "MLETORCH-847: Boolean eq result unstable on U85" ,
139
+ "eq_scalar_rank4_randn" : "MLETORCH-847: Boolean eq result unstable on U85" ,
139
140
},
141
+ strict = False ,
140
142
)
141
143
@common .XfailIfNoCorstone320
142
144
def test_eq_u85_BI (test_module ):
143
145
pipeline = EthosU85PipelineBI [input_t ](
144
146
test_module ,
145
147
test_module .get_inputs (),
146
- Equal .aten_op_BI ,
148
+ Equal .aten_op_Tensor ,
147
149
Equal .exir_op ,
148
150
run_on_fvp = True ,
149
151
)
0 commit comments