@@ -460,6 +460,7 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
460
460
([None , None , None , index_1d ],),
461
461
([index_1d , None ],),
462
462
([index_1d , None , None ],),
463
+ # Extra index
463
464
([None , index_1d , None , index_1d ],),
464
465
([index_1d , None , index_1d , None ],),
465
466
([None , index_1d , index_1d , None ],),
@@ -468,6 +469,7 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
468
469
([None , None , None , index_2d ],),
469
470
([index_2d , None ],),
470
471
([index_2d , None , None ],),
472
+ # Extra index
471
473
([None , index_2d , None , index_2d ],),
472
474
([index_2d , None , index_2d , None ],),
473
475
([None , index_2d , index_2d , None ],),
@@ -476,11 +478,15 @@ def sample_inputs_index(op_info, device, dtype, requires_grad, **kwargs):
476
478
([None , None , None , index_3d ],),
477
479
([index_3d , None ],),
478
480
([index_3d , None , None ],),
481
+ # Extra index
479
482
([None , index_3d , None , index_3d ],),
480
483
([index_3d , None , index_3d , None ],),
481
484
([None , index_3d , index_3d , None ],),
482
485
# Mixed indices
483
486
([None , index_3d , index_1d , index_2d ],),
487
+ # All indices are not None
488
+ ([index_2d , index_3d , index_1d ],),
489
+ ([index_2d , index_3d , index_1d , index_2d ],),
484
490
]
485
491
486
492
for args in test_args :
0 commit comments