Skip to content

Commit 7637144

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] remove batch_dims from make bounding boxes and detection masks (#7855)
Summary: (Note: this ignores all push blocking failures!) Reviewed By: matteobettini Differential Revision: D48900373 fbshipit-source-id: d429799e810493fbed98f7b2980b1e1c40e777f8
1 parent 1a5837a commit 7637144

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

test/common_utils.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -406,26 +406,21 @@ def make_bounding_boxes(
406406
canvas_size=DEFAULT_SIZE,
407407
*,
408408
format=datapoints.BoundingBoxFormat.XYXY,
409-
batch_dims=(),
410409
dtype=None,
411410
device="cpu",
412411
):
413412
def sample_position(values, max_value):
414413
# We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
415414
# However, if we have batch_dims, we need tensors as limits.
416-
return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape)
415+
return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])
417416

418417
if isinstance(format, str):
419418
format = datapoints.BoundingBoxFormat[format]
420419

421420
dtype = dtype or torch.float32
422421

423-
if any(dim == 0 for dim in batch_dims):
424-
return datapoints.BoundingBoxes(
425-
torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
426-
)
427-
428-
h, w = [torch.randint(1, c, batch_dims) for c in canvas_size]
422+
num_objects = 1
423+
h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]
429424
y = sample_position(h, canvas_size[0])
430425
x = sample_position(w, canvas_size[1])
431426

@@ -448,11 +443,12 @@ def sample_position(values, max_value):
448443
)
449444

450445

451-
def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"):
446+
def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
452447
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
448+
num_objects = 1
453449
return datapoints.Mask(
454450
torch.testing.make_tensor(
455-
(*batch_dims, num_objects, *size),
451+
(num_objects, *size),
456452
low=0,
457453
high=2,
458454
dtype=dtype or torch.bool,

0 commit comments

Comments
 (0)