@@ -452,8 +452,8 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
452
452
)
453
453
s = 5
454
454
index_1d = common_methods_invocations .index_variable (2 , s , device = device )
455
- index_2d = common_methods_invocations .index_variable ((2 , s + 1 ), s , device = device )
456
- index_3d = common_methods_invocations .index_variable ((2 , s + 1 , s + 2 ), s , device = device )
455
+ index_2d = common_methods_invocations .index_variable ((s + 1 , 2 ), s , device = device )
456
+ index_3d = common_methods_invocations .index_variable ((s + 2 , s + 1 , 2 ), s , device = device )
457
457
test_args = [
458
458
([index_1d ],),
459
459
([None , index_1d ],),
@@ -463,6 +463,24 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
463
463
([None , index_1d , None , index_1d ],),
464
464
([index_1d , None , index_1d , None ],),
465
465
([None , index_1d , index_1d , None ],),
466
+ ([index_2d ],),
467
+ ([None , index_2d ],),
468
+ ([None , None , None , index_2d ],),
469
+ ([index_2d , None ],),
470
+ ([index_2d , None , None ],),
471
+ ([None , index_2d , None , index_2d ],),
472
+ ([index_2d , None , index_2d , None ],),
473
+ ([None , index_2d , index_2d , None ],),
474
+ ([index_3d ],),
475
+ ([None , index_3d ],),
476
+ ([None , None , None , index_3d ],),
477
+ ([index_3d , None ],),
478
+ ([index_3d , None , None ],),
479
+ ([None , index_3d , None , index_3d ],),
480
+ ([index_3d , None , index_3d , None ],),
481
+ ([None , index_3d , index_3d , None ],),
482
+ # Mixed indices
483
+ ([None , index_3d , index_1d , index_2d ],),
466
484
]
467
485
468
486
for args in test_args :
0 commit comments