@@ -199,21 +199,31 @@ def resize_bounding_box():
199
199
yield SampleInput (bounding_box , size = size , image_size = bounding_box .image_size )
200
200
201
201
202
- class TestKernelsCommon :
203
- @pytest .mark .parametrize ("functional_info" , FUNCTIONAL_INFOS , ids = lambda functional_info : functional_info .name )
204
- def test_scriptable (self , functional_info ):
205
- jit .script (functional_info .functional )
206
-
207
- @pytest .mark .parametrize (
208
- ("functional_info" , "sample_input" ),
209
- [
210
- pytest .param (functional_info , sample_input , id = f"{ functional_info .name } -{ idx } " )
211
- for functional_info in FUNCTIONAL_INFOS
212
- for idx , sample_input in enumerate (functional_info .sample_inputs ())
213
- ],
214
- )
215
- def test_eager_vs_scripted (self , functional_info , sample_input ):
216
- eager = functional_info (sample_input )
217
- scripted = jit .script (functional_info .functional )(* sample_input .args , ** sample_input .kwargs )
218
-
219
- torch .testing .assert_close (eager , scripted )
202
+ @pytest .mark .parametrize (
203
+ "kernel" ,
204
+ [
205
+ pytest .param (kernel , id = name )
206
+ for name , kernel in F .__dict__ .items ()
207
+ if not name .startswith ("_" )
208
+ and callable (kernel )
209
+ and any (feature_type in name for feature_type in {"image" , "segmentation_mask" , "bounding_box" , "label" })
210
+ and "pil" not in name
211
+ ],
212
+ )
213
+ def test_scriptable (kernel ):
214
+ jit .script (kernel )
215
+
216
+
217
+ @pytest .mark .parametrize (
218
+ ("functional_info" , "sample_input" ),
219
+ [
220
+ pytest .param (functional_info , sample_input , id = f"{ functional_info .name } -{ idx } " )
221
+ for functional_info in FUNCTIONAL_INFOS
222
+ for idx , sample_input in enumerate (functional_info .sample_inputs ())
223
+ ],
224
+ )
225
+ def test_eager_vs_scripted (functional_info , sample_input ):
226
+ eager = functional_info (sample_input )
227
+ scripted = jit .script (functional_info .functional )(* sample_input .args , ** sample_input .kwargs )
228
+
229
+ torch .testing .assert_close (eager , scripted )
0 commit comments