2020
2121
2222class Equal (torch .nn .Module ):
23- aten_op_BI = "torch.ops.aten.eq.Tensor"
24- aten_op_MI = "torch.ops.aten.eq.Scalar"
23+ aten_op_Tensor = "torch.ops.aten.eq.Tensor"
24+ aten_op_Scalar = "torch.ops.aten.eq.Scalar"
2525 exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor"
2626
2727 def __init__ (self , input , other ):
@@ -80,7 +80,7 @@ def get_inputs(self):
8080@common .parametrize ("test_module" , test_data_tensor )
8181def test_eq_tensor_tosa_MI (test_module ):
8282 pipeline = TosaPipelineMI [input_t ](
83- test_module , test_module .get_inputs (), Equal .aten_op_BI , Equal .exir_op
83+ test_module , test_module .get_inputs (), Equal .aten_op_Tensor , Equal .exir_op
8484 )
8585 pipeline .run ()
8686
@@ -90,7 +90,7 @@ def test_eq_scalar_tosa_MI(test_module):
9090 pipeline = TosaPipelineMI [input_t ](
9191 test_module ,
9292 test_module .get_inputs (),
93- Equal .aten_op_MI ,
93+ Equal .aten_op_Scalar ,
9494 Equal .exir_op ,
9595 )
9696 pipeline .run ()
@@ -99,7 +99,7 @@ def test_eq_scalar_tosa_MI(test_module):
9999@common .parametrize ("test_module" , test_data_tensor | test_data_scalar )
100100def test_eq_tosa_BI (test_module ):
101101 pipeline = TosaPipelineBI [input_t ](
102- test_module , test_module .get_inputs (), Equal .aten_op_BI , Equal .exir_op
102+ test_module , test_module .get_inputs (), Equal .aten_op_Tensor , Equal .exir_op
103103 )
104104 pipeline .run ()
105105
@@ -135,15 +135,17 @@ def test_eq_scalar_u55_BI(test_module):
135135 "test_module" ,
136136 test_data_tensor | test_data_scalar ,
137137 xfails = {
138- "eq_tensor_rank4_randn" : "4D fails because boolean Tensors can't be subtracted" ,
138+ "eq_tensor_rank4_randn" : "MLETORCH-847: Boolean eq result unstable on U85" ,
139+ "eq_scalar_rank4_randn" : "MLETORCH-847: Boolean eq result unstable on U85" ,
139140 },
141+ strict = False ,
140142)
141143@common .XfailIfNoCorstone320
142144def test_eq_u85_BI (test_module ):
143145 pipeline = EthosU85PipelineBI [input_t ](
144146 test_module ,
145147 test_module .get_inputs (),
146- Equal .aten_op_BI ,
148+ Equal .aten_op_Tensor ,
147149 Equal .exir_op ,
148150 run_on_fvp = True ,
149151 )
0 commit comments