66
77import pytest
88import torch
9- from common_utils import assert_equal
9+ from common_utils import assert_equal , cpu_and_gpu
1010from test_prototype_transforms_functional import (
1111 make_bounding_box ,
1212 make_bounding_boxes ,
1515 make_one_hot_labels ,
1616 make_segmentation_mask ,
1717)
18+ from torchvision .ops .boxes import box_iou
1819from torchvision .prototype import features , transforms
1920from torchvision .transforms .functional import InterpolationMode , pil_to_tensor , to_pil_image
2021
@@ -793,7 +794,7 @@ def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
793794 if p > 0.0 :
794795 fn .assert_called_once_with (inpt , ** kwargs )
795796 else :
796- fn .call_count == 0
797+ assert fn .call_count == 0
797798
798799
799800class TestRandomPerspective :
@@ -1014,7 +1015,7 @@ def test__transform(self, p, inpt_type, mocker):
10141015 if p > 0.0 :
10151016 fn .assert_called_once_with (erase_image_tensor_inpt , ** params )
10161017 else :
1017- fn .call_count == 0
1018+ assert fn .call_count == 0
10181019
10191020
10201021class TestTransform :
@@ -1050,7 +1051,7 @@ def test__transform(self, inpt_type, mocker):
10501051 transform = transforms .ToImageTensor ()
10511052 transform (inpt )
10521053 if inpt_type in (features .BoundingBox , str , int ):
1053- fn .call_count == 0
1054+ assert fn .call_count == 0
10541055 else :
10551056 fn .assert_called_once_with (inpt , copy = transform .copy )
10561057
@@ -1067,7 +1068,7 @@ def test__transform(self, inpt_type, mocker):
10671068 transform = transforms .ToImagePIL ()
10681069 transform (inpt )
10691070 if inpt_type in (features .BoundingBox , str , int ):
1070- fn .call_count == 0
1071+ assert fn .call_count == 0
10711072 else :
10721073 fn .assert_called_once_with (inpt , copy = transform .copy )
10731074
@@ -1085,7 +1086,7 @@ def test__transform(self, inpt_type, mocker):
10851086 transform = transforms .ToPILImage ()
10861087 transform (inpt )
10871088 if inpt_type in (PIL .Image .Image , features .BoundingBox , str , int ):
1088- fn .call_count == 0
1089+ assert fn .call_count == 0
10891090 else :
10901091 fn .assert_called_once_with (inpt , mode = transform .mode )
10911092
@@ -1103,7 +1104,7 @@ def test__transform(self, inpt_type, mocker):
11031104 transform = transforms .ToTensor ()
11041105 transform (inpt )
11051106 if inpt_type in (features .Image , torch .Tensor , features .BoundingBox , str , int ):
1106- fn .call_count == 0
1107+ assert fn .call_count == 0
11071108 else :
11081109 fn .assert_called_once_with (inpt )
11091110
@@ -1127,6 +1128,124 @@ def test_ctor(self, trfms):
11271128 assert isinstance (output , torch .Tensor )
11281129
11291130
class TestRandomIoUCrop:
    """Tests for ``transforms.RandomIoUCrop``: parameter sampling, empty-params handling and target filtering."""

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
    def test__get_params(self, device, options, mocker):
        """Check that sampled crop parameters respect the scale bounds and the requested IoU options."""
        image = mocker.MagicMock(spec=features.Image)
        image.num_channels = 3
        image.image_size = (24, 32)
        bboxes = features.BoundingBox(
            torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
            format="XYXY",
            image_size=image.image_size,
            device=device,
        )
        sample = [image, bboxes]

        transform = transforms.RandomIoUCrop(sampler_options=options)
        orig_h, orig_w = image.image_size

        n_samples = 5
        for _ in range(n_samples):
            params = transform._get_params(sample)

            if options == [2.0]:
                # For this option the test expects an empty params dict. Keep looping instead of
                # returning so that every one of the ``n_samples`` draws is actually verified.
                assert len(params) == 0
                continue

            assert len(params["is_within_crop_area"]) > 0
            assert params["is_within_crop_area"].dtype == torch.bool

            left, top = params["left"], params["top"]
            new_h, new_w = params["height"], params["width"]
            assert int(transform.min_scale * orig_h) <= new_h <= int(transform.max_scale * orig_h)
            assert int(transform.min_scale * orig_w) <= new_w <= int(transform.max_scale * orig_w)

            crop_box = torch.tensor(
                [[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device
            )
            ious = box_iou(bboxes, crop_box)
            # The best overlap has to satisfy at least one of the options, i.e. the smallest one.
            assert ious.max() >= min(options), f"{ious} vs {options}"

    def test__transform_empty_params(self, mocker):
        """Check that the sample is returned unchanged when ``_get_params`` yields an empty dict."""
        transform = transforms.RandomIoUCrop(sampler_options=[2.0])
        image = features.Image(torch.rand(1, 3, 4, 4))
        bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4))
        label = features.Label(torch.tensor([1]))
        sample = [image, bboxes, label]
        # Let's mock transform._get_params to control the output:
        transform._get_params = mocker.MagicMock(return_value={})
        output = transform(sample)
        torch.testing.assert_close(output, sample)

    def test_forward_assertion(self):
        """Check that a sample without the required feature types is rejected with a ``TypeError``."""
        transform = transforms.RandomIoUCrop()
        with pytest.raises(
            TypeError,
            match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels",
        ):
            transform(torch.tensor(0))

    def test__transform(self, mocker):
        """Check that ``crop`` is dispatched for spatial inputs and targets are filtered by ``is_within_crop_area``."""
        transform = transforms.RandomIoUCrop()

        image = features.Image(torch.rand(3, 32, 24))
        bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
        label = features.Label(torch.randint(0, 10, size=(6,)))
        ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
        masks = make_segmentation_mask((32, 24))
        ohe_masks = features.SegmentationMask(torch.randint(0, 2, size=(6, 32, 24)))
        sample = [image, bboxes, label, ohe_label, masks, ohe_masks]

        # The patched crop returns its input untouched, so only the filtering step alters the targets.
        fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x)
        is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)

        params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
        transform._get_params = mocker.MagicMock(return_value=params)
        output = transform(sample)

        assert fn.call_count == 4

        # Only the geometric parameters are forwarded to crop; the mask of kept targets is not.
        crop_kwargs = {key: params[key] for key in ("top", "left", "height", "width")}
        expected_calls = [mocker.call(inpt, **crop_kwargs) for inpt in (image, bboxes, masks, ohe_masks)]
        fn.assert_has_calls(expected_calls)

        # Plain Python int so the length comparisons below evaluate to regular bools.
        expected_within_targets = int(is_within_crop_area.sum())

        # check number of bboxes vs number of labels:
        output_bboxes = output[1]
        assert isinstance(output_bboxes, features.BoundingBox)
        assert len(output_bboxes) == expected_within_targets

        # check labels
        output_label = output[2]
        assert isinstance(output_label, features.Label)
        assert len(output_label) == expected_within_targets
        torch.testing.assert_close(output_label, label[is_within_crop_area])

        output_ohe_label = output[3]
        assert isinstance(output_ohe_label, features.OneHotLabel)
        torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])

        # The first mask must keep its leading dimensions.
        output_masks = output[4]
        assert isinstance(output_masks, features.SegmentationMask)
        assert output_masks.shape[:-2] == masks.shape[:-2]

        # The (6, H, W) masks are filtered like the boxes.
        output_ohe_masks = output[5]
        assert isinstance(output_ohe_masks, features.SegmentationMask)
        assert len(output_ohe_masks) == expected_within_targets
1248+
11301249class TestScaleJitter :
11311250 def test__get_params (self , mocker ):
11321251 image_size = (24 , 32 )
0 commit comments