1212import torch
1313from executorch .backends .arm .test import common
1414from executorch .backends .arm .test .tester .arm_tester import ArmTester
15+ from parameterized import parameterized
1516
1617logger = logging .getLogger (__name__ )
1718logger .setLevel (logging .INFO )
@@ -126,6 +127,32 @@ def forward(self, x):
126127 return x
127128
128129
130+ class ComboConvRelu6 (torch .nn .Module ):
131+ edge_op_list = [
132+ "executorch_exir_dialects_edge__ops_aten_convolution_default" ,
133+ "executorch_exir_dialects_edge__ops_aten_hardtanh_default" ,
134+ ]
135+
136+ test_data = [
137+ (20 * torch .randn (1 , 3 , 256 , 256 ),),
138+ (5 * torch .randn (1 , 3 , 256 , 256 ),),
139+ (torch .randn (1 , 3 , 256 , 256 ),),
140+ (- 5 * torch .randn (1 , 3 , 256 , 256 ),),
141+ ]
142+
143+ def __init__ (self ):
144+ super ().__init__ ()
145+ self .conv2d = torch .nn .Conv2d (
146+ in_channels = 3 , out_channels = 3 , kernel_size = 3 , stride = 1 , groups = 1
147+ )
148+ self .relu6 = torch .nn .ReLU6 ()
149+
150+ def forward (self , x ):
151+ x = self .conv2d (x )
152+ x = self .relu6 (x )
153+ return x
154+
155+
129156class TestConvCombos (unittest .TestCase ):
130157 def _test_conv_combo_tosa_MI_pipeline (
131158 self , module : torch .nn .Module , test_data : Tuple [torch .Tensor ]
@@ -222,15 +249,9 @@ def test_conv_batchnorm_relu_tosa_MI(self):
222249 model = ComboConvBatchnormRelu ()
223250 self ._test_conv_combo_tosa_MI_pipeline (model , model .get_inputs ())
224251
225- # TODO(MLETORCH-85): Investigate numerical issue. This diff is present in legacy
226- # testcase as well (and also not tested). For now, just increase the
227- # tolerance, such that we don't skip the test entirely (i.e. we maintain
228- # functionality).
229252 def test_conv_batchnorm_relu_tosa_BI (self ):
230253 model = ComboConvBatchnormRelu ()
231- self ._test_conv_combo_tosa_BI_pipeline (
232- model , model .get_inputs (), atol = 1.0 , rtol = 1.0
233- )
254+ self ._test_conv_combo_tosa_BI_pipeline (model , model .get_inputs ())
234255
235256 @unittest .skipIf (
236257 not common .VELA_INSTALLED ,
@@ -240,21 +261,41 @@ def test_conv_batchnorm_relu_u55_BI(self):
240261 model = ComboConvBatchnormRelu ()
241262 self ._test_conv_combo_u55_BI_pipeline (model , model .get_inputs ())
242263
264+ ##################
265+ ## Conv + ReLU6 ##
266+ ##################
267+ @parameterized .expand (ComboConvRelu6 .test_data )
268+ def test_conv_relu6_tosa_MI (self , test_data : torch .Tensor ):
269+ model = ComboConvRelu6 ()
270+ test_data = (test_data ,)
271+ self ._test_conv_combo_tosa_MI_pipeline (model , test_data )
272+
273+ @parameterized .expand (ComboConvRelu6 .test_data )
274+ def test_conv_relu6_tosa_BI (self , test_data : torch .Tensor ):
275+ model = ComboConvRelu6 ()
276+ test_data = (test_data ,)
277+ self ._test_conv_combo_tosa_BI_pipeline (model , test_data )
278+
279+ @parameterized .expand (ComboConvRelu6 .test_data )
280+ @unittest .skipIf (
281+ not common .VELA_INSTALLED ,
282+ "There is no point in running U55 tests if the Vela tool is not installed" ,
283+ )
284+ def test_conv_relu6_u55_BI (self , test_data : torch .Tensor ):
285+ model = ComboConvRelu6 ()
286+ test_data = (test_data ,)
287+ self ._test_conv_combo_u55_BI_pipeline (model , test_data )
288+
243289 ###############################
244290 ## Block bottleneck residual ##
245291 ###############################
246292 def test_block_bottleneck_residual_tosa_MI (self ):
247293 model = ComboBlockBottleneckResidual ()
248294 self ._test_conv_combo_tosa_MI_pipeline (model , model .get_inputs ())
249295
250- # TODO(MLETORCH-85): Investigate numerical issue. This diff was present in legacy
251- # testcase as well. For now, just increase the tolerance, such that
252- # we don't skip the test entirely (i.e. we maintain functionality).
253296 def test_block_bottleneck_residual_tosa_BI (self ):
254297 model = ComboBlockBottleneckResidual ()
255- self ._test_conv_combo_tosa_BI_pipeline (
256- model , model .get_inputs (), atol = 1.0 , rtol = 1.0
257- )
298+ self ._test_conv_combo_tosa_BI_pipeline (model , model .get_inputs ())
258299
259300 @unittest .skipIf (
260301 not common .VELA_INSTALLED ,
0 commit comments