15
15
from datasets_utils import combinations_grid
16
16
from torch .nn .functional import one_hot
17
17
from torch .testing ._comparison import assert_equal as _assert_equal , BooleanPair , NonePair , NumberPair , TensorLikePair
18
- from torchvision .prototype import features
18
+ from torchvision .prototype import datapoints
19
19
from torchvision .prototype .transforms .functional import convert_dtype_image_tensor , to_image_tensor
20
20
from torchvision .transforms .functional_tensor import _max_value as get_max_value
21
21
@@ -238,7 +238,7 @@ def load(self, device):
238
238
239
239
@dataclasses .dataclass
240
240
class ImageLoader (TensorLoader ):
241
- color_space : features .ColorSpace
241
+ color_space : datapoints .ColorSpace
242
242
spatial_size : Tuple [int , int ] = dataclasses .field (init = False )
243
243
num_channels : int = dataclasses .field (init = False )
244
244
@@ -248,10 +248,10 @@ def __post_init__(self):
248
248
249
249
250
250
NUM_CHANNELS_MAP = {
251
- features .ColorSpace .GRAY : 1 ,
252
- features .ColorSpace .GRAY_ALPHA : 2 ,
253
- features .ColorSpace .RGB : 3 ,
254
- features .ColorSpace .RGB_ALPHA : 4 ,
251
+ datapoints .ColorSpace .GRAY : 1 ,
252
+ datapoints .ColorSpace .GRAY_ALPHA : 2 ,
253
+ datapoints .ColorSpace .RGB : 3 ,
254
+ datapoints .ColorSpace .RGB_ALPHA : 4 ,
255
255
}
256
256
257
257
@@ -265,7 +265,7 @@ def get_num_channels(color_space):
265
265
def make_image_loader (
266
266
size = "random" ,
267
267
* ,
268
- color_space = features .ColorSpace .RGB ,
268
+ color_space = datapoints .ColorSpace .RGB ,
269
269
extra_dims = (),
270
270
dtype = torch .float32 ,
271
271
constant_alpha = True ,
@@ -276,9 +276,9 @@ def make_image_loader(
276
276
def fn (shape , dtype , device ):
277
277
max_value = get_max_value (dtype )
278
278
data = torch .testing .make_tensor (shape , low = 0 , high = max_value , dtype = dtype , device = device )
279
- if color_space in {features .ColorSpace .GRAY_ALPHA , features .ColorSpace .RGB_ALPHA } and constant_alpha :
279
+ if color_space in {datapoints .ColorSpace .GRAY_ALPHA , datapoints .ColorSpace .RGB_ALPHA } and constant_alpha :
280
280
data [..., - 1 , :, :] = max_value
281
- return features .Image (data , color_space = color_space )
281
+ return datapoints .Image (data , color_space = color_space )
282
282
283
283
return ImageLoader (fn , shape = (* extra_dims , num_channels , * size ), dtype = dtype , color_space = color_space )
284
284
@@ -290,10 +290,10 @@ def make_image_loaders(
290
290
* ,
291
291
sizes = DEFAULT_SPATIAL_SIZES ,
292
292
color_spaces = (
293
- features .ColorSpace .GRAY ,
294
- features .ColorSpace .GRAY_ALPHA ,
295
- features .ColorSpace .RGB ,
296
- features .ColorSpace .RGB_ALPHA ,
293
+ datapoints .ColorSpace .GRAY ,
294
+ datapoints .ColorSpace .GRAY_ALPHA ,
295
+ datapoints .ColorSpace .RGB ,
296
+ datapoints .ColorSpace .RGB_ALPHA ,
297
297
),
298
298
extra_dims = DEFAULT_EXTRA_DIMS ,
299
299
dtypes = (torch .float32 , torch .uint8 ),
@@ -306,7 +306,7 @@ def make_image_loaders(
306
306
make_images = from_loaders (make_image_loaders )
307
307
308
308
309
- def make_image_loader_for_interpolation (size = "random" , * , color_space = features .ColorSpace .RGB , dtype = torch .uint8 ):
309
+ def make_image_loader_for_interpolation (size = "random" , * , color_space = datapoints .ColorSpace .RGB , dtype = torch .uint8 ):
310
310
size = _parse_spatial_size (size )
311
311
num_channels = get_num_channels (color_space )
312
312
@@ -318,24 +318,24 @@ def fn(shape, dtype, device):
318
318
.resize ((width , height ))
319
319
.convert (
320
320
{
321
- features .ColorSpace .GRAY : "L" ,
322
- features .ColorSpace .GRAY_ALPHA : "LA" ,
323
- features .ColorSpace .RGB : "RGB" ,
324
- features .ColorSpace .RGB_ALPHA : "RGBA" ,
321
+ datapoints .ColorSpace .GRAY : "L" ,
322
+ datapoints .ColorSpace .GRAY_ALPHA : "LA" ,
323
+ datapoints .ColorSpace .RGB : "RGB" ,
324
+ datapoints .ColorSpace .RGB_ALPHA : "RGBA" ,
325
325
}[color_space ]
326
326
)
327
327
)
328
328
329
329
image_tensor = convert_dtype_image_tensor (to_image_tensor (image_pil ).to (device = device ), dtype = dtype )
330
330
331
- return features .Image (image_tensor , color_space = color_space )
331
+ return datapoints .Image (image_tensor , color_space = color_space )
332
332
333
333
return ImageLoader (fn , shape = (num_channels , * size ), dtype = dtype , color_space = color_space )
334
334
335
335
336
336
def make_image_loaders_for_interpolation (
337
337
sizes = ((233 , 147 ),),
338
- color_spaces = (features .ColorSpace .RGB ,),
338
+ color_spaces = (datapoints .ColorSpace .RGB ,),
339
339
dtypes = (torch .uint8 ,),
340
340
):
341
341
for params in combinations_grid (size = sizes , color_space = color_spaces , dtype = dtypes ):
@@ -344,7 +344,7 @@ def make_image_loaders_for_interpolation(
344
344
345
345
@dataclasses .dataclass
346
346
class BoundingBoxLoader (TensorLoader ):
347
- format : features .BoundingBoxFormat
347
+ format : datapoints .BoundingBoxFormat
348
348
spatial_size : Tuple [int , int ]
349
349
350
350
@@ -362,11 +362,11 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
362
362
363
363
def make_bounding_box_loader (* , extra_dims = (), format , spatial_size = "random" , dtype = torch .float32 ):
364
364
if isinstance (format , str ):
365
- format = features .BoundingBoxFormat [format ]
365
+ format = datapoints .BoundingBoxFormat [format ]
366
366
if format not in {
367
- features .BoundingBoxFormat .XYXY ,
368
- features .BoundingBoxFormat .XYWH ,
369
- features .BoundingBoxFormat .CXCYWH ,
367
+ datapoints .BoundingBoxFormat .XYXY ,
368
+ datapoints .BoundingBoxFormat .XYWH ,
369
+ datapoints .BoundingBoxFormat .CXCYWH ,
370
370
}:
371
371
raise pytest .UsageError (f"Can't make bounding box in format { format } " )
372
372
@@ -378,19 +378,19 @@ def fn(shape, dtype, device):
378
378
raise pytest .UsageError ()
379
379
380
380
if any (dim == 0 for dim in extra_dims ):
381
- return features .BoundingBox (
381
+ return datapoints .BoundingBox (
382
382
torch .empty (* extra_dims , 4 , dtype = dtype , device = device ), format = format , spatial_size = spatial_size
383
383
)
384
384
385
385
height , width = spatial_size
386
386
387
- if format == features .BoundingBoxFormat .XYXY :
387
+ if format == datapoints .BoundingBoxFormat .XYXY :
388
388
x1 = torch .randint (0 , width // 2 , extra_dims )
389
389
y1 = torch .randint (0 , height // 2 , extra_dims )
390
390
x2 = randint_with_tensor_bounds (x1 + 1 , width - x1 ) + x1
391
391
y2 = randint_with_tensor_bounds (y1 + 1 , height - y1 ) + y1
392
392
parts = (x1 , y1 , x2 , y2 )
393
- elif format == features .BoundingBoxFormat .XYWH :
393
+ elif format == datapoints .BoundingBoxFormat .XYWH :
394
394
x = torch .randint (0 , width // 2 , extra_dims )
395
395
y = torch .randint (0 , height // 2 , extra_dims )
396
396
w = randint_with_tensor_bounds (1 , width - x )
@@ -403,7 +403,7 @@ def fn(shape, dtype, device):
403
403
h = randint_with_tensor_bounds (1 , torch .minimum (cy , height - cy ) + 1 )
404
404
parts = (cx , cy , w , h )
405
405
406
- return features .BoundingBox (
406
+ return datapoints .BoundingBox (
407
407
torch .stack (parts , dim = - 1 ).to (dtype = dtype , device = device ), format = format , spatial_size = spatial_size
408
408
)
409
409
@@ -416,7 +416,7 @@ def fn(shape, dtype, device):
416
416
def make_bounding_box_loaders (
417
417
* ,
418
418
extra_dims = DEFAULT_EXTRA_DIMS ,
419
- formats = tuple (features .BoundingBoxFormat ),
419
+ formats = tuple (datapoints .BoundingBoxFormat ),
420
420
spatial_size = "random" ,
421
421
dtypes = (torch .float32 , torch .int64 ),
422
422
):
@@ -456,7 +456,7 @@ def fn(shape, dtype, device):
456
456
# The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
457
457
# regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
458
458
data = torch .testing .make_tensor (shape , low = 0 , high = num_categories , dtype = torch .int64 , device = device ).to (dtype )
459
- return features .Label (data , categories = categories )
459
+ return datapoints .Label (data , categories = categories )
460
460
461
461
return LabelLoader (fn , shape = extra_dims , dtype = dtype , categories = categories )
462
462
@@ -480,7 +480,7 @@ def fn(shape, dtype, device):
480
480
# since `one_hot` only supports int64
481
481
label = make_label_loader (extra_dims = extra_dims , categories = num_categories , dtype = torch .int64 ).load (device )
482
482
data = one_hot (label , num_classes = num_categories ).to (dtype )
483
- return features .OneHotLabel (data , categories = categories )
483
+ return datapoints .OneHotLabel (data , categories = categories )
484
484
485
485
return OneHotLabelLoader (fn , shape = (* extra_dims , num_categories ), dtype = dtype , categories = categories )
486
486
@@ -509,7 +509,7 @@ def make_detection_mask_loader(size="random", *, num_objects="random", extra_dim
509
509
510
510
def fn (shape , dtype , device ):
511
511
data = torch .testing .make_tensor (shape , low = 0 , high = 2 , dtype = dtype , device = device )
512
- return features .Mask (data )
512
+ return datapoints .Mask (data )
513
513
514
514
return MaskLoader (fn , shape = (* extra_dims , num_objects , * size ), dtype = dtype )
515
515
@@ -537,7 +537,7 @@ def make_segmentation_mask_loader(size="random", *, num_categories="random", ext
537
537
538
538
def fn (shape , dtype , device ):
539
539
data = torch .testing .make_tensor (shape , low = 0 , high = num_categories , dtype = dtype , device = device )
540
- return features .Mask (data )
540
+ return datapoints .Mask (data )
541
541
542
542
return MaskLoader (fn , shape = (* extra_dims , * size ), dtype = dtype )
543
543
@@ -583,7 +583,7 @@ class VideoLoader(ImageLoader):
583
583
def make_video_loader (
584
584
size = "random" ,
585
585
* ,
586
- color_space = features .ColorSpace .RGB ,
586
+ color_space = datapoints .ColorSpace .RGB ,
587
587
num_frames = "random" ,
588
588
extra_dims = (),
589
589
dtype = torch .uint8 ,
@@ -593,7 +593,7 @@ def make_video_loader(
593
593
594
594
def fn (shape , dtype , device ):
595
595
video = make_image (size = shape [- 2 :], color_space = color_space , extra_dims = shape [:- 3 ], dtype = dtype , device = device )
596
- return features .Video (video , color_space = color_space )
596
+ return datapoints .Video (video , color_space = color_space )
597
597
598
598
return VideoLoader (
599
599
fn , shape = (* extra_dims , num_frames , get_num_channels (color_space ), * size ), dtype = dtype , color_space = color_space
@@ -607,8 +607,8 @@ def make_video_loaders(
607
607
* ,
608
608
sizes = DEFAULT_SPATIAL_SIZES ,
609
609
color_spaces = (
610
- features .ColorSpace .GRAY ,
611
- features .ColorSpace .RGB ,
610
+ datapoints .ColorSpace .GRAY ,
611
+ datapoints .ColorSpace .RGB ,
612
612
),
613
613
num_frames = (1 , 0 , "random" ),
614
614
extra_dims = DEFAULT_EXTRA_DIMS ,
0 commit comments