@@ -446,6 +446,22 @@ def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
446446 yield opinfo_core .SampleInput (tensor , args = (output_size , kernel_size ), kwargs = kwargs )
447447
448448
449+ def sample_inputs_index (op_info , device , dtype , requires_grad , ** kwargs ):
450+ del op_info # Unused
451+ del kwargs # Unused
452+ make_arg = functools .partial (
453+ torch_testing .make_tensor , dtype = dtype , device = device , requires_grad = requires_grad
454+ )
455+ s = 5
456+ test_args = [
457+ ([common_methods_invocations .index_variable (2 , s , device = device )],),
458+ # ([torch.tensor()],)
459+ ]
460+
461+ for args in test_args :
462+ yield opinfo_core .SampleInput (make_arg ((s , s , s )), args = args )
463+
464+
449465def sample_inputs_stft (op_info , device , dtype , requires_grad , ** kwargs ):
450466 del op_info
451467 del kwargs
@@ -581,9 +597,17 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
581597 skips = (),
582598 supports_out = False ,
583599 ),
600+ opinfo_core .OpInfo (
601+ "aten.index.Tensor" ,
602+ dtypes = common_dtype .all_types_and_complex_and (
603+ torch .bool , torch .float16 , torch .bfloat16 , torch .chalf
604+ ),
605+ aten_name = "index" ,
606+ op = torch .ops .aten .index .Tensor ,
607+ sample_inputs_func = sample_inputs_index ,
608+ ),
584609 opinfo_core .OpInfo (
585610 "layer_norm" ,
586- aliases = ("layer_norm" ,),
587611 aten_name = "layer_norm" ,
588612 dtypes = common_dtype .floating_and_complex_types_and (torch .int64 , torch .bfloat16 ),
589613 sample_inputs_func = sample_inputs_layer_norm ,
0 commit comments