Skip to content

Commit aec38fc

Browse files
authored
[proto][tests] added ref tests for resize bboxes (#6879)
1 parent 1502ed9 commit aec38fc

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,31 @@ def sample_inputs_resize_video():
288288
yield ArgsKwargs(video_loader, size=[min(video_loader.shape[-2:]) + 1])
289289

290290

291+
def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=None):
292+
293+
old_height, old_width = spatial_size
294+
new_height, new_width = F._geometry._compute_resized_output_size(spatial_size, size=size, max_size=max_size)
295+
296+
affine_matrix = np.array(
297+
[
298+
[new_width / old_width, 0, 0],
299+
[0, new_height / old_height, 0],
300+
],
301+
dtype="float32",
302+
)
303+
304+
expected_bboxes = reference_affine_bounding_box_helper(
305+
bounding_box, format=bounding_box.format, affine_matrix=affine_matrix
306+
)
307+
return expected_bboxes, (new_height, new_width)
308+
309+
310+
def reference_inputs_resize_bounding_box():
311+
for bounding_box_loader in make_bounding_box_loaders(extra_dims=((), (4,))):
312+
for size in _get_resize_sizes(bounding_box_loader.spatial_size):
313+
yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
314+
315+
291316
KERNEL_INFOS.extend(
292317
[
293318
KernelInfo(
@@ -303,6 +328,8 @@ def sample_inputs_resize_video():
303328
KernelInfo(
304329
F.resize_bounding_box,
305330
sample_inputs_fn=sample_inputs_resize_bounding_box,
331+
reference_fn=reference_resize_bounding_box,
332+
reference_inputs_fn=reference_inputs_resize_bounding_box,
306333
test_marks=[
307334
xfail_jit_python_scalar_arg("size"),
308335
],
@@ -459,7 +486,7 @@ def transform(bbox, affine_matrix_, format_):
459486
],
460487
)
461488
out_bbox = F.convert_format_bounding_box(
462-
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
489+
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
463490
)
464491
return out_bbox.to(dtype=in_dtype)
465492

0 commit comments

Comments
 (0)