6
6
7
7
import pytest
8
8
import torch
9
- from common_utils import assert_equal
9
+ from common_utils import assert_equal , cpu_and_gpu
10
10
from test_prototype_transforms_functional import (
11
11
make_bounding_box ,
12
12
make_bounding_boxes ,
15
15
make_one_hot_labels ,
16
16
make_segmentation_mask ,
17
17
)
18
+ from torchvision .ops .boxes import box_iou
18
19
from torchvision .prototype import features , transforms
19
20
from torchvision .transforms .functional import InterpolationMode , pil_to_tensor , to_pil_image
20
21
@@ -793,7 +794,7 @@ def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
793
794
if p > 0.0 :
794
795
fn .assert_called_once_with (inpt , ** kwargs )
795
796
else :
796
- fn .call_count == 0
797
+ assert fn .call_count == 0
797
798
798
799
799
800
class TestRandomPerspective :
@@ -1014,7 +1015,7 @@ def test__transform(self, p, inpt_type, mocker):
1014
1015
if p > 0.0 :
1015
1016
fn .assert_called_once_with (erase_image_tensor_inpt , ** params )
1016
1017
else :
1017
- fn .call_count == 0
1018
+ assert fn .call_count == 0
1018
1019
1019
1020
1020
1021
class TestTransform :
@@ -1050,7 +1051,7 @@ def test__transform(self, inpt_type, mocker):
1050
1051
transform = transforms .ToImageTensor ()
1051
1052
transform (inpt )
1052
1053
if inpt_type in (features .BoundingBox , str , int ):
1053
- fn .call_count == 0
1054
+ assert fn .call_count == 0
1054
1055
else :
1055
1056
fn .assert_called_once_with (inpt , copy = transform .copy )
1056
1057
@@ -1067,7 +1068,7 @@ def test__transform(self, inpt_type, mocker):
1067
1068
transform = transforms .ToImagePIL ()
1068
1069
transform (inpt )
1069
1070
if inpt_type in (features .BoundingBox , str , int ):
1070
- fn .call_count == 0
1071
+ assert fn .call_count == 0
1071
1072
else :
1072
1073
fn .assert_called_once_with (inpt , copy = transform .copy )
1073
1074
@@ -1085,7 +1086,7 @@ def test__transform(self, inpt_type, mocker):
1085
1086
transform = transforms .ToPILImage ()
1086
1087
transform (inpt )
1087
1088
if inpt_type in (PIL .Image .Image , features .BoundingBox , str , int ):
1088
- fn .call_count == 0
1089
+ assert fn .call_count == 0
1089
1090
else :
1090
1091
fn .assert_called_once_with (inpt , mode = transform .mode )
1091
1092
@@ -1103,7 +1104,7 @@ def test__transform(self, inpt_type, mocker):
1103
1104
transform = transforms .ToTensor ()
1104
1105
transform (inpt )
1105
1106
if inpt_type in (features .Image , torch .Tensor , features .BoundingBox , str , int ):
1106
- fn .call_count == 0
1107
+ assert fn .call_count == 0
1107
1108
else :
1108
1109
fn .assert_called_once_with (inpt )
1109
1110
@@ -1127,6 +1128,124 @@ def test_ctor(self, trfms):
1127
1128
assert isinstance (output , torch .Tensor )
1128
1129
1129
1130
1131
class TestRandomIoUCrop:
    """Tests for ``transforms.RandomIoUCrop``.

    Covers parameter sampling (``_get_params``), the pass-through path taken
    when no crop parameters are produced, input validation, and the way the
    crop and the ``is_within_crop_area`` mask are applied to each feature type.
    """

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
    def test__get_params(self, device, options, mocker):
        """Sampled params must respect the scale bounds and the IoU options."""
        image = mocker.MagicMock(spec=features.Image)
        image.num_channels = 3
        image.image_size = (24, 32)
        bboxes = features.BoundingBox(
            torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
            format="XYXY",
            image_size=image.image_size,
            device=device,
        )
        sample = [image, bboxes]

        transform = transforms.RandomIoUCrop(sampler_options=options)

        # The image size never changes between samples, so unpack it once.
        orig_h, orig_w = image.image_size

        n_samples = 5
        for _ in range(n_samples):
            params = transform._get_params(sample)

            if options == [2.0]:
                # With the single option 2.0 this test expects no crop
                # parameters at all; one draw is enough to check that.
                assert len(params) == 0
                return

            assert len(params["is_within_crop_area"]) > 0
            assert params["is_within_crop_area"].dtype == torch.bool

            assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h)
            assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w)

            left, top = params["left"], params["top"]
            new_h, new_w = params["height"], params["width"]
            ious = box_iou(
                bboxes,
                torch.tensor([[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device),
            )
            # At least one box must overlap the crop by one of the configured
            # thresholds, i.e. by at least the smallest of them.
            assert ious.max() >= min(options), f"{ious} vs {options}"

    def test__transform_empty_params(self, mocker):
        """An empty params dict must leave the sample unchanged."""
        transform = transforms.RandomIoUCrop(sampler_options=[2.0])
        image = features.Image(torch.rand(1, 3, 4, 4))
        bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4))
        label = features.Label(torch.tensor([1]))
        sample = [image, bboxes, label]
        # Let's mock transform._get_params to control the output:
        transform._get_params = mocker.MagicMock(return_value={})
        output = transform(sample)
        torch.testing.assert_close(output, sample)

    def test_forward_assertion(self):
        """A sample without the required feature types must raise ``TypeError``."""
        transform = transforms.RandomIoUCrop()
        with pytest.raises(
            TypeError,
            match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels",
        ):
            transform(torch.tensor(0))

    def test__transform(self, mocker):
        """Spatial features are cropped; per-box features are filtered by the mask."""
        transform = transforms.RandomIoUCrop()

        image = features.Image(torch.rand(3, 32, 24))
        bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
        label = features.Label(torch.randint(0, 10, size=(6,)))
        ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
        masks = make_segmentation_mask((32, 24))
        ohe_masks = features.SegmentationMask(torch.randint(0, 2, size=(6, 32, 24)))
        sample = [image, bboxes, label, ohe_label, masks, ohe_masks]

        # Replace the functional crop with an identity so that only the
        # dispatch and the mask-based filtering are exercised here.
        fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x)
        is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)

        params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
        transform._get_params = mocker.MagicMock(return_value=params)
        output = transform(sample)

        # Labels carry no spatial extent, so only four inputs reach crop().
        assert fn.call_count == 4

        crop_kwargs = {name: params[name] for name in ("top", "left", "height", "width")}
        expected_calls = [mocker.call(inpt, **crop_kwargs) for inpt in (image, bboxes, masks, ohe_masks)]
        fn.assert_has_calls(expected_calls)

        # Plain int count of the entries flagged as inside the crop area.
        expected_within_targets = int(is_within_crop_area.sum())

        # check number of bboxes vs number of labels:
        output_bboxes = output[1]
        assert isinstance(output_bboxes, features.BoundingBox)
        assert len(output_bboxes) == expected_within_targets

        # check labels
        output_label = output[2]
        assert isinstance(output_label, features.Label)
        assert len(output_label) == expected_within_targets
        torch.testing.assert_close(output_label, label[is_within_crop_area])

        output_ohe_label = output[3]
        assert isinstance(output_ohe_label, features.OneHotLabel)
        torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])

        # The plain segmentation mask keeps its leading dimensions ...
        output_masks = output[4]
        assert isinstance(output_masks, features.SegmentationMask)
        assert output_masks.shape[:-2] == masks.shape[:-2]

        # ... while the per-object masks are filtered like the boxes.
        output_ohe_masks = output[5]
        assert isinstance(output_ohe_masks, features.SegmentationMask)
        assert len(output_ohe_masks) == expected_within_targets
1247
+
1248
+
1130
1249
class TestScaleJitter :
1131
1250
def test__get_params (self , mocker ):
1132
1251
image_size = (24 , 32 )
0 commit comments