@@ -162,7 +162,7 @@ def _unbatch(self, batch, *, data_dims):
162
162
def test_batched_vs_single (self , test_id , info , args_kwargs , device ):
163
163
(batched_input , * other_args ), kwargs = args_kwargs .load (device )
164
164
165
- feature_type = (
165
+ datapoint_type = (
166
166
datapoints .Image
167
167
if torchvision .prototype .transforms .utils .is_simple_tensor (batched_input )
168
168
else type (batched_input )
@@ -178,10 +178,10 @@ def test_batched_vs_single(self, test_id, info, args_kwargs, device):
178
178
# common ground.
179
179
datapoints .Mask : 2 ,
180
180
datapoints .Video : 4 ,
181
- }.get (feature_type )
181
+ }.get (datapoint_type )
182
182
if data_dims is None :
183
183
raise pytest .UsageError (
184
- f"The number of data dimensions cannot be determined for input of type { feature_type .__name__ } ."
184
+ f"The number of data dimensions cannot be determined for input of type { datapoint_type .__name__ } ."
185
185
) from None
186
186
elif batched_input .ndim <= data_dims :
187
187
pytest .skip ("Input is not batched." )
@@ -323,8 +323,8 @@ def test_logging(self, spy_on, info, args_kwargs, device):
323
323
def test_scripted_smoke (self , info , args_kwargs , device ):
324
324
dispatcher = script (info .dispatcher )
325
325
326
- (image_feature , * other_args ), kwargs = args_kwargs .load (device )
327
- image_simple_tensor = torch .Tensor (image_feature )
326
+ (image_datapoint , * other_args ), kwargs = args_kwargs .load (device )
327
+ image_simple_tensor = torch .Tensor (image_datapoint )
328
328
329
329
dispatcher (image_simple_tensor , * other_args , ** kwargs )
330
330
@@ -352,8 +352,8 @@ def test_scriptable(self, dispatcher):
352
352
353
353
@image_sample_inputs
354
354
def test_dispatch_simple_tensor (self , info , args_kwargs , spy_on ):
355
- (image_feature , * other_args ), kwargs = args_kwargs .load ()
356
- image_simple_tensor = torch .Tensor (image_feature )
355
+ (image_datapoint , * other_args ), kwargs = args_kwargs .load ()
356
+ image_simple_tensor = torch .Tensor (image_datapoint )
357
357
358
358
kernel_info = info .kernel_infos [datapoints .Image ]
359
359
spy = spy_on (kernel_info .kernel , module = info .dispatcher .__module__ , name = kernel_info .id )
@@ -367,12 +367,12 @@ def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
367
367
args_kwargs_fn = lambda info : info .sample_inputs (datapoints .Image ),
368
368
)
369
369
def test_dispatch_pil (self , info , args_kwargs , spy_on ):
370
- (image_feature , * other_args ), kwargs = args_kwargs .load ()
370
+ (image_datapoint , * other_args ), kwargs = args_kwargs .load ()
371
371
372
- if image_feature .ndim > 3 :
372
+ if image_datapoint .ndim > 3 :
373
373
pytest .skip ("Input is batched" )
374
374
375
- image_pil = F .to_image_pil (image_feature )
375
+ image_pil = F .to_image_pil (image_datapoint )
376
376
377
377
pil_kernel_info = info .pil_kernel_info
378
378
spy = spy_on (pil_kernel_info .kernel , module = info .dispatcher .__module__ , name = pil_kernel_info .id )
@@ -385,37 +385,39 @@ def test_dispatch_pil(self, info, args_kwargs, spy_on):
385
385
DISPATCHER_INFOS ,
386
386
args_kwargs_fn = lambda info : info .sample_inputs (),
387
387
)
388
- def test_dispatch_feature (self , info , args_kwargs , spy_on ):
389
- (feature , * other_args ), kwargs = args_kwargs .load ()
388
+ def test_dispatch_datapoint (self , info , args_kwargs , spy_on ):
389
+ (datapoint , * other_args ), kwargs = args_kwargs .load ()
390
390
391
391
method_name = info .id
392
- method = getattr (feature , method_name )
393
- feature_type = type (feature )
394
- spy = spy_on (method , module = feature_type .__module__ , name = f"{ feature_type .__name__ } .{ method_name } " )
392
+ method = getattr (datapoint , method_name )
393
+ datapoint_type = type (datapoint )
394
+ spy = spy_on (method , module = datapoint_type .__module__ , name = f"{ datapoint_type .__name__ } .{ method_name } " )
395
395
396
- info .dispatcher (feature , * other_args , ** kwargs )
396
+ info .dispatcher (datapoint , * other_args , ** kwargs )
397
397
398
398
spy .assert_called_once ()
399
399
400
400
@pytest .mark .parametrize (
401
- ("dispatcher_info" , "feature_type " , "kernel_info" ),
401
+ ("dispatcher_info" , "datapoint_type " , "kernel_info" ),
402
402
[
403
- pytest .param (dispatcher_info , feature_type , kernel_info , id = f"{ dispatcher_info .id } -{ feature_type .__name__ } " )
403
+ pytest .param (
404
+ dispatcher_info , datapoint_type , kernel_info , id = f"{ dispatcher_info .id } -{ datapoint_type .__name__ } "
405
+ )
404
406
for dispatcher_info in DISPATCHER_INFOS
405
- for feature_type , kernel_info in dispatcher_info .kernel_infos .items ()
407
+ for datapoint_type , kernel_info in dispatcher_info .kernel_infos .items ()
406
408
],
407
409
)
408
- def test_dispatcher_kernel_signatures_consistency (self , dispatcher_info , feature_type , kernel_info ):
410
+ def test_dispatcher_kernel_signatures_consistency (self , dispatcher_info , datapoint_type , kernel_info ):
409
411
dispatcher_signature = inspect .signature (dispatcher_info .dispatcher )
410
412
dispatcher_params = list (dispatcher_signature .parameters .values ())[1 :]
411
413
412
414
kernel_signature = inspect .signature (kernel_info .kernel )
413
415
kernel_params = list (kernel_signature .parameters .values ())[1 :]
414
416
415
- # We filter out metadata that is implicitly passed to the dispatcher through the input feature , but has to be
417
+ # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint , but has to be
416
418
# explicit passed to the kernel.
417
- feature_type_metadata = feature_type .__annotations__ .keys ()
418
- kernel_params = [param for param in kernel_params if param .name not in feature_type_metadata ]
419
+ datapoint_type_metadata = datapoint_type .__annotations__ .keys ()
420
+ kernel_params = [param for param in kernel_params if param .name not in datapoint_type_metadata ]
419
421
420
422
dispatcher_params = iter (dispatcher_params )
421
423
for dispatcher_param , kernel_param in zip (dispatcher_params , kernel_params ):
@@ -433,26 +435,26 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature
433
435
assert dispatcher_param == kernel_param
434
436
435
437
@pytest .mark .parametrize ("info" , DISPATCHER_INFOS , ids = lambda info : info .id )
436
- def test_dispatcher_feature_signatures_consistency (self , info ):
438
+ def test_dispatcher_datapoint_signatures_consistency (self , info ):
437
439
try :
438
- feature_method = getattr (datapoints ._datapoint .Datapoint , info .id )
440
+ datapoint_method = getattr (datapoints ._datapoint .Datapoint , info .id )
439
441
except AttributeError :
440
- pytest .skip ("Dispatcher doesn't support arbitrary feature dispatch." )
442
+ pytest .skip ("Dispatcher doesn't support arbitrary datapoint dispatch." )
441
443
442
444
dispatcher_signature = inspect .signature (info .dispatcher )
443
445
dispatcher_params = list (dispatcher_signature .parameters .values ())[1 :]
444
446
445
- feature_signature = inspect .signature (feature_method )
446
- feature_params = list (feature_signature .parameters .values ())[1 :]
447
+ datapoint_signature = inspect .signature (datapoint_method )
448
+ datapoint_params = list (datapoint_signature .parameters .values ())[1 :]
447
449
448
- # Because we use `from __future__ import annotations` inside the module where `features ._datapoint` is defined,
449
- # the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively
450
- # concrete dispatcher annotations.
451
- feature_annotations = get_type_hints (feature_method )
452
- for param in feature_params :
453
- param ._annotation = feature_annotations [param .name ]
450
+ # Because we use `from __future__ import annotations` inside the module where `datapoints ._datapoint` is
451
+ # defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
452
+ # natively concrete dispatcher annotations.
453
+ datapoint_annotations = get_type_hints (datapoint_method )
454
+ for param in datapoint_params :
455
+ param ._annotation = datapoint_annotations [param .name ]
454
456
455
- assert dispatcher_params == feature_params
457
+ assert dispatcher_params == datapoint_params
456
458
457
459
@pytest .mark .parametrize ("info" , DISPATCHER_INFOS , ids = lambda info : info .id )
458
460
def test_unkown_type (self , info ):
0 commit comments