@@ -610,15 +610,6 @@ def test_jit_boxes_list(self):
610
610
self ._helper_jit_boxes_list (model )
611
611
612
612
613
- optests .generate_opcheck_tests (
614
- testcase = TestRoIAlign ,
615
- namespaces = ["torchvision" ],
616
- failures_dict_path = os .path .join (os .path .dirname (__file__ ), "optests_failures_dict.json" ),
617
- additional_decorators = [],
618
- test_utils = OPTESTS ,
619
- )
620
-
621
-
622
613
class TestPSRoIAlign (RoIOpTester ):
623
614
mps_backward_atol = 5e-2
624
615
@@ -676,6 +667,43 @@ def test_boxes_shape(self):
676
667
self ._helper_boxes_shape (ops .ps_roi_align )
677
668
678
669
670
+ @pytest .mark .parametrize (
671
+ "op" ,
672
+ (
673
+ torch .ops .torchvision .roi_pool ,
674
+ torch .ops .torchvision .ps_roi_pool ,
675
+ torch .ops .torchvision .roi_align ,
676
+ torch .ops .torchvision .ps_roi_align ,
677
+ ),
678
+ )
679
+ @pytest .mark .parametrize ("dtype" , (torch .float16 , torch .float32 , torch .float64 ))
680
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
681
+ @pytest .mark .parametrize ("requires_grad" , (True , False ))
682
+ def test_roi_opcheck (op , dtype , device , requires_grad ):
683
+ # This manually calls opcheck() on the roi ops. We do that instead of
684
+ # relying on opcheck.generate_opcheck_tests() as e.g. done for nms, because
685
+ # pytest and generate_opcheck_tests() don't interact very well when it comes
686
+ # to skipping tests - and these ops need to skip the MPS tests since MPS we
687
+ # don't support dynamic shapes yet for MPS.
688
+ rois = torch .tensor (
689
+ [[0 , 0 , 0 , 9 , 9 ], [0 , 0 , 5 , 4 , 9 ], [0 , 5 , 5 , 9 , 9 ], [1 , 0 , 0 , 9 , 9 ]],
690
+ dtype = dtype ,
691
+ device = device ,
692
+ requires_grad = requires_grad ,
693
+ )
694
+ pool_size = 5
695
+ num_channels = 2 * (pool_size ** 2 )
696
+ x = torch .rand (2 , num_channels , 10 , 10 , dtype = dtype , device = device )
697
+
698
+ kwargs = dict (rois = rois , spatial_scale = 1 , pooled_height = pool_size , pooled_width = pool_size )
699
+ if op in (torch .ops .torchvision .roi_align , torch .ops .torchvision .ps_roi_align ):
700
+ kwargs ["sampling_ratio" ] = - 1
701
+ if op is torch .ops .torchvision .roi_align :
702
+ kwargs ["aligned" ] = True
703
+
704
+ optests .opcheck (op , args = (x ,), kwargs = kwargs )
705
+
706
+
679
707
class TestMultiScaleRoIAlign :
680
708
def make_obj (self , fmap_names = None , output_size = (7 , 7 ), sampling_ratio = 2 , wrap = False ):
681
709
if fmap_names is None :
0 commit comments