@@ -406,26 +406,21 @@ def make_bounding_boxes(
406
406
canvas_size = DEFAULT_SIZE ,
407
407
* ,
408
408
format = datapoints .BoundingBoxFormat .XYXY ,
409
- batch_dims = (),
410
409
dtype = None ,
411
410
device = "cpu" ,
412
411
):
413
412
def sample_position (values , max_value ):
414
413
# We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
415
414
# However, if we have batch_dims, we need tensors as limits.
416
- return torch .stack ([torch .randint (max_value - v , ()) for v in values .flatten (). tolist ()]). reshape ( values . shape )
415
+ return torch .stack ([torch .randint (max_value - v , ()) for v in values .tolist ()])
417
416
418
417
if isinstance (format , str ):
419
418
format = datapoints .BoundingBoxFormat [format ]
420
419
421
420
dtype = dtype or torch .float32
422
421
423
- if any (dim == 0 for dim in batch_dims ):
424
- return datapoints .BoundingBoxes (
425
- torch .empty (* batch_dims , 4 , dtype = dtype , device = device ), format = format , canvas_size = canvas_size
426
- )
427
-
428
- h , w = [torch .randint (1 , c , batch_dims ) for c in canvas_size ]
422
+ num_objects = 1
423
+ h , w = [torch .randint (1 , c , (num_objects ,)) for c in canvas_size ]
429
424
y = sample_position (h , canvas_size [0 ])
430
425
x = sample_position (w , canvas_size [1 ])
431
426
@@ -448,11 +443,12 @@ def sample_position(values, max_value):
448
443
)
449
444
450
445
451
- def make_detection_mask (size = DEFAULT_SIZE , * , num_objects = 5 , batch_dims = (), dtype = None , device = "cpu" ):
446
+ def make_detection_mask (size = DEFAULT_SIZE , * , dtype = None , device = "cpu" ):
452
447
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
448
+ num_objects = 1
453
449
return datapoints .Mask (
454
450
torch .testing .make_tensor (
455
- (* batch_dims , num_objects , * size ),
451
+ (num_objects , * size ),
456
452
low = 0 ,
457
453
high = 2 ,
458
454
dtype = dtype or torch .bool ,
0 commit comments